Libraries

suppressPackageStartupMessages({
  # Load required packages
  library(rlang)
  library(dplyr)
  library(tidyr)
  library(ggplot2)
  library(viridisLite)
  library(pals)
  library(ComplexHeatmap)
  library(circlize)
  library(dendsort)
  library(ggtree)
  library(cowplot)
  library(ggsignif)
  library(scran)
  library(aricode)
  library(ggthemr)
  library(mclust)
  library(effsize)
  library(broom)
  source("benchmarkUtils.R")
})

#' Min-max scale a numeric vector to the [0, 1] interval.
#'
#' Missing values are ignored when the range is computed and stay missing in
#' the result. If every non-missing value is identical the range is zero and
#' the result is `NaN` (0 / 0).
#'
#' @param x A numeric vector.
#' @return A numeric vector of the same length as `x`.
minMaxScaler <- function(x) {
  # Spell out TRUE: `T` is an ordinary variable that can be reassigned.
  rng <- range(x, na.rm = TRUE)
  (x - rng[1L]) / (rng[2L] - rng[1L])
}
# Bind these verb names to their dplyr versions explicitly.
# NOTE(review): presumably another attached package exports the same names
# and would otherwise mask them -- confirm which one.
select <- dplyr::select
rename <- dplyr::rename
slice <- dplyr::slice

# Fixed colour for each prediction method (and the RNA-Seq references), used
# by the manual colour/fill scales of the method-comparison plots.
dl_method_pal <- c(
  "TCGN"            = "#CA9C91",
  "THItoGene"       = "#BA7FB5",
  "EGNv1"           = "#B3D46B",
  "EGNv2"           = "#F7CBDF",
  "DeepPT"          = "#80B1D2",
  "DeepSpaCE"       = "#F18072",
  "ST-Net"          = "#8CD0C3",
  "HisToGene"       = "#f1c232",
  "Hist2ST"         = "#F9B063",
  "GeneCodeR"       = "#BCB9D8",
  "iStar"           = "#D9B3FF",
  "RNA-Seq"         = "#BABABA",
  "RNA-Seq-STgenes" = "#BABABA"
)

# Swatch handed to ggthemr; its first and third entries also define the
# continuous gradient below.
mycolors <- c(
  "#999999", "#0072B2", "#E69F00", "#F0E442", "#D9B3FF",
  "#009E73", "#D55E00", "#5D8AA8", "#CC79A7", "#56B4E9",
  "#F3B3A6", "#A5AB81", "#B2182B", "#4393C3", "#CDBE6B",
  "#80CDC1", "#F4A582", "#BABABA", "#CCEBC5", "#DECBE4",
  "#FDDFDF", "#B3DE69", "#FDBF6F", "#CCECE6", "#FB8072"
)

# Set ggplot theme: reset ggthemr, apply its "pale" theme, then switch to the
# custom palette built from the swatch above.
ggthemr_reset()
ggthemr("pale")

mypalette <- define_palette(
  swatch = mycolors,
  gradient = c(lower = mycolors[1L], upper = mycolors[3L])
)

# for some reason it uses the first colour as the colour of the grid lines
ggthemr(mypalette)
theme_update(panel.grid.major = element_line(linetype = "dotted"))

# Shared theme tweaks added to individual plots: larger text, horizontal
# x-axis labels, no grid lines and a black panel border.
th <- theme(
  text = element_text(size = 16),
  axis.text.x = element_text(angle = 0, hjust = 0.5),
  panel.grid.major = element_blank(),
  panel.grid.minor = element_blank(),
  # `linewidth` replaces the `size` argument of element_rect(), which is
  # deprecated as of ggplot2 3.4.0.
  panel.background = element_rect(colour = "black", linewidth = 0.7, fill = NA)
)
## Warning: The `size` argument of `element_rect()` is deprecated as of ggplot2 3.4.0.
## ℹ Please use the `linewidth` argument instead.
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_lifecycle_warnings()` to see where this warning was
## generated.

Load Data

Load data for HER2ST dataset
# Load data
# All expression and prediction data; the raw GeneCodeR run id is replaced by
# the method's display name.
comb_pred_dat <- readRDS("data/processed/her2st/her2st_comb_pred_dat_11.rds") %>%
  mutate(model_id = case_when(
    model_id == "genecoder_i500_j500" ~ "GeneCodeR",
    TRUE ~ model_id
  ))

# Calculated metrics, with the same GeneCodeR relabelling
pred_feat_cor <- readRDS("data/processed/her2st/her2st_pred_feat_cor_11.rds") %>%
  mutate(model_id = if_else(
    model_id == "genecoder_i500_j500", "GeneCodeR", model_id
  ))

# Original gene expression per patch
deeppt_exprs_df <- read.csv("data/processed/her2st/processed_expr.csv")

# Highly Variable Genes
# use methods from scran (Lun, McCarthy, and Marioni 2016) to identify a set of top highly variable genes (HVGs)
# The `X` column is dropped and the table transposed before modelling
# (presumably so that genes end up in rows -- TODO confirm the CSV layout).
dec_her2st <- modelGeneVar(t(select(deeppt_exprs_df, -X)))
## Warning in regularize.values(x, y, ties, missing(ties), na.rm = na.rm):
## collapsing to unique 'x' values
# Get the top 10% HVGs.
hv_genes <- getTopHVGs(dec_her2st, prop = 0.1)

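comb_pred_dat is a data frame with the following columns (it also contains a row_id column, visible in the glimpse() output below, which is not described here):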

  • img_id: id of image
  • model_id: the name of method predicting gene expression
  • patch_id: id of image patch, containing image + patch coordinates
  • pred_type: the data split (“train”, “test” or “validation”) that the prediction belongs to
  • train_fold: an id for the fold the model was trained and tested on
  • gene: name of the gene that was predicted
  • pred: predicted gene expression count
  • exprs: ground truth gene expression count
# Print the dimensions, column types and leading values of comb_pred_dat
glimpse(comb_pred_dat)
## Rows: 431,388,115
## Columns: 9
## $ img_id     <chr> "A1", "A1", "A1", "A1", "A1", "A1", "A1", "A1", "A1", "A1",…
## $ model_id   <chr> "HisToGene", "HisToGene", "HisToGene", "HisToGene", "HisToG…
## $ pred_type  <chr> "test", "test", "test", "test", "test", "test", "test", "te…
## $ train_fold <chr> "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1",…
## $ row_id     <int> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,…
## $ gene       <chr> "HPS6", "TNC", "NR1H2", "NUP93", "HNRNPUL2", "MARS", "SUGP1…
## $ exprs      <dbl> 1.037466, 0.000000, 1.037466, 1.037466, 0.000000, 1.318105,…
## $ pred       <dbl> 0.10249946, 0.35236460, 0.27433290, 0.16560593, 0.67725086,…
## $ patch_id   <chr> "A1_10x13", "A1_10x13", "A1_10x13", "A1_10x13", "A1_10x13",…

pred_feat_cor is a dataframe with the following columns:

  • img_id: id of image
  • model_id: the name of method predicting gene expression
  • gene: name of the gene that was predicted
  • pred_type: whether the evaluation metrics were calculated on the “train”, “test” or “validation” sets
  • train_fold: an id for the fold the model was trained and tested on
  • cor_pearson: Pearson correlation
  • cor_spearman: Spearman correlation
  • var_exprs: variance of the gene expression data the model was trained on
  • var_exprs_orig: variance of the ground truth gene expression counts
  • var_pred: variance of the predicted gene expression
  • mean_exprs: mean of the gene expression data the model was trained on
  • mean_exprs_orig: mean of the ground truth gene expression counts
  • mean_pred: mean of the predicted gene expression
  • rmse: root mean squared error
  • mi: mutual information
  • js_div: Jensen-Shannon divergence
  • nrmse_range: RMSE normalised using the range of the data
  • nrmse_sd: RMSE normalised using the standard deviation of data
  • ssim: structural similarity index
  • auc_n: the original ground truth gene expression is binarised into \(GE \le n\) and \(GE > n\); auc_n is the AUC of the predicted gene expression for distinguishing these two classes
# Print the dimensions, column types and leading values of pred_feat_cor
glimpse(pred_feat_cor)
## Rows: 1,252,075
## Columns: 26
## $ gene            <chr> "A2M", "A2M", "A2M", "A2M", "A2M", "A2M", "A2M", "A2M"…
## $ pred_type       <chr> "test", "test", "test", "test", "test", "test", "test"…
## $ train_fold      <chr> "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1",…
## $ model_id        <chr> "DeepPT", "DeepPT", "DeepPT", "DeepPT", "DeepPT", "Dee…
## $ img_id          <chr> "A1", "A2", "A3", "A4", "A5", "A6", "H1", "H2", "H3", …
## $ cor_pearson     <dbl> 0.1427815292, 0.0377918044, -0.0007100285, 0.110365113…
## $ cor_spearman    <dbl> 0.10018397, 0.02585776, 0.02662824, 0.16275219, 0.1073…
## $ var_exprs       <dbl> 2.48163693, 13.44934473, 6.98599462, 11.93492234, 8.19…
## $ var_exprs_orig  <dbl> 2.481637, 13.449345, 6.985995, 11.934922, 8.195792, 9.…
## $ var_pred        <dbl> 7.176909616, 6.396816870, 6.668901264, 11.005013152, 8…
## $ mean_exprs      <dbl> 1.0491329, 2.9138462, 2.4122563, 2.5043732, 2.4759036,…
## $ mean_exprs_orig <dbl> 1.049133, 2.913846, 2.412256, 2.504373, 2.475904, 2.11…
## $ mean_pred       <dbl> 4.6499880, 4.0803390, 4.0179606, 4.4986984, 4.2622191,…
## $ rmse            <dbl> 4.6254865, 4.5218239, 4.0255201, 4.9323499, 4.2498333,…
## $ mi              <dbl> 0.03551507, 0.02940891, 0.03831184, 0.03667822, 0.0454…
## $ js_div          <dbl> 0.3740230, 0.2863111, 0.2739019, 0.3136536, 0.2794005,…
## $ nrmse_range     <dbl> 0.4625486, 0.1458653, 0.2683680, 0.1972940, 0.2499902,…
## $ nrmse_sd        <dbl> 2.9362180, 1.2330000, 1.5230280, 1.4277234, 1.4844871,…
## $ ssim            <dbl> 0.097091520, 0.035521460, 0.008964697, 0.087410856, 0.…
## $ auc_0           <dbl> 0.5678295, 0.5037837, 0.5285893, 0.5776532, 0.5640391,…
## $ auc_1           <dbl> 0.5282170, 0.5181714, 0.5046305, 0.5917741, 0.5450191,…
## $ auc_2           <dbl> 0.5167220, 0.5004427, 0.5016291, 0.6033259, 0.5625710,…
## $ auc_5           <dbl> 0.5440476, 0.5073972, 0.5243597, 0.5297674, 0.5374114,…
## $ auc_7           <dbl> 0.4532164, 0.5328498, 0.4780186, 0.5421700, 0.5942428,…
## $ auc_10          <dbl> NA, 0.57219251, 0.44759207, 0.57237237, 0.68209877, 0.…
## $ auc_20          <dbl> NA, 0.9382716, NA, 0.5073314, NA, 0.1337047, NA, NA, N…
Load data for CSCC dataset
# Prediction metrics, with the raw GeneCodeR run id replaced by its display name
pred_feat_cor_cscc <- readRDS("data/processed/cscc/cscc_pred_feat_cor_11.rds") %>%
  mutate(model_id = if_else(
    model_id == "genecoder_i500_j500", "GeneCodeR", model_id
  ))

# Original gene expression per patch
deeppt_exprs_df_cscc <- read.csv("data/processed/cscc/processed_expr.csv")

# Highly Variable Genes
# The `X` column is dropped and the table transposed before modelling.
dec_cscc <- modelGeneVar(t(select(deeppt_exprs_df_cscc, -X)))
## Warning in regularize.values(x, y, ties, missing(ties), na.rm = na.rm):
## collapsing to unique 'x' values
# Get the top 10% of genes.
hv_genes_cscc <- getTopHVGs(dec_cscc, prop = 0.1)
Load data for HER2+ trained models used for Visium-HER2+ prediction
# Prediction metrics, with the raw GeneCodeR run id replaced by its display name
pred_feat_cor_her2_bc <- readRDS("data/processed/visium/visium_her2_bc_pred_feat_cor_10.rds") %>%
  mutate(model_id = if_else(
    model_id == "genecoder_i500_j500", "GeneCodeR", model_id
  ))
Load data for Visium-Hercep-Test2+ (used for training) and Visium-HER2+ (used for testing)
# Prediction results of using 990 HVGs for training and testing
pred_feat_cor_whole_bc_990 <- readRDS("data/processed/visium/visium_whole_bc_pred_feat_cor_990_genes_6.rds") %>%
  mutate(model_id = if_else(
    model_id == "genecoder_i500_j500", "GeneCodeR", model_id
  ))

# Prediction results of using 274 High-sparsity genes (HSGs) for training and testing
pred_feat_cor_whole_bc_274 <- readRDS("data/processed/visium/visium_whole_bc_pred_feat_cor_274_genes_7.rds") %>%
  mutate(model_id = if_else(
    model_id == "genecoder_i500_j500", "GeneCodeR", model_id
  ))

# Original gene expression per patch
deeppt_exprs_df_whole_bc <- read.csv("data/processed/visium/processed_visium_bc_expr_990.csv")

# Highly Variable Genes
# The `X` column is dropped and the table transposed before modelling.
dec_whole_bc <- modelGeneVar(t(select(deeppt_exprs_df_whole_bc, -X)))
# Get the top 10% of genes.
hv_genes_whole_bc <- getTopHVGs(dec_whole_bc, prop = 0.1)
Load data for Visium-Kidney dataset
# Prediction results of using 992 HVGs for training and testing
pred_feat_cor_kidney_992 <- readRDS("data/processed/visium/visium_kidney_pred_feat_cor_992_genes_6.rds") %>%
  mutate(model_id = if_else(
    model_id == "genecoder_i500_j500", "GeneCodeR", model_id
  ))

# Prediction results of using 145 High-sparsity genes (HSGs) for training and testing
pred_feat_cor_kidney_145 <- readRDS("data/processed/visium/visium_kidney_pred_feat_cor_145_genes_8.rds") %>%
  mutate(model_id = if_else(
    model_id == "genecoder_i500_j500", "GeneCodeR", model_id
  ))

# Original gene expression per patch
deeppt_exprs_df_kidney <- read.csv("data/processed/visium/processed_expr_kidney.csv")

# Highly Variable Genes
# The `X` column is dropped and the table transposed before modelling.
dec_kidney <- modelGeneVar(t(select(deeppt_exprs_df_kidney, -X)))
# Get the top 10% of genes.
hv_genes_kidney <- getTopHVGs(dec_kidney, prop = 0.1)
Load SVGs
# SVG tables, one per dataset (presumably "spatially variable genes" -- the
# files themselves are not inspected here). They are used later as semi_join()
# filters on the metric tables.
# HER2+ (joined later on gene and img_id)
svg_genes <- readRDS("data/processed/her2st/her2st_SVG_filter_top20.rds")
# cSCC (joined later on gene and img_id)
svg_genes_cscc <- readRDS("data/processed/cscc/cscc_SVG_filter_top20.rds")
# Visium-HER2+ (joined later on gene and img_id)
svg_bc <- readRDS("data/processed/visium/visium_bc_SVG_filter_top20.rds")
# Visium-Kidney (joined later on gene only)
svg_kidney <- readRDS("data/processed/visium/kidney_SVG_filter_top20_whole_image.rds")

Image - Predicted vs Truth - Figure 2c

# Plot predicted vs ground-truth expression of one gene (FASN) on one image
# (B1), using test-set predictions only.
filt_gene_df <- comb_pred_dat %>%
  filter(pred_type == "test") %>%
  filter(img_id == "B1" &
           gene %in% "FASN") %>%
  # patch_id has the form "<img_id>_<x>x<y>" (e.g. "A1_10x13"): split it into
  # the image id and the two patch coordinates.
  separate("patch_id",
           c("img_id", "patch_coord"),
           sep = "_",
           remove = FALSE) %>%
  separate("patch_coord", c("x", "y"), sep = "x")

filt_gene_df <- filt_gene_df %>%
  # Add a "Ground Truth" pseudo-method whose "prediction" is the measured
  # expression. The DeepPT rows are used as the source of `exprs`
  # (NOTE(review): assumes every patch occurs exactly once there -- confirm).
  bind_rows(
    filt_gene_df %>%
      filter(model_id == "DeepPT") %>%
      mutate(pred = exprs, model_id = "Ground Truth")
  ) %>%
  # separate() yields character coordinates; convert them for plotting.
  mutate(across(c(x, y), as.numeric)) %>%
  # Scale each method to [0, 1] separately so the panels share a colour scale.
  group_by(model_id, gene) %>%
  mutate(pred = minMaxScaler(pred)) %>%
  # Drop the grouping so it cannot leak into later operations.
  ungroup()

fig2_exprs <- filt_gene_df %>%
  # Fix the panel order: ground truth first, then the methods.
  mutate(model_id = factor(
    model_id,
    levels = c(
      "Ground Truth","ST-Net", "HisToGene", "GeneCodeR", "DeepSpaCE", "DeepPT", "Hist2ST", "EGNv1", "EGNv2", "TCGN", "THItoGene", "iStar"
    )
  )) %>%
  ggplot() +
  # y is negated, which flips the vertical axis of the patch grid.
  aes(x = x, y = -y, col = pred) +
  geom_point(size = 1.2) +
  facet_wrap(~ model_id, scales = "free",
             ncol = 12) +
  scale_color_viridis_c() +
  labs(col = "Scaled Expression", x = "", y = "") +
  theme(
    axis.ticks = element_blank(),
    axis.text = element_blank(),
    axis.line = element_blank(),
    panel.grid.major = element_blank(), panel.grid.minor = element_blank(),
    legend.text = element_text(size = 10),
    legend.key.size = unit(15, 'points'),
    legend.title = element_text(size = 12),
    strip.text.x = element_text(size = 12)
  )

fig2_exprs

Figure 2d

Boxplot/violin plots of the average PCC, MI, SSIM and AUC between the ground truth gene expression and predicted gene expression. Metrics measured from the test fold of a 4-fold CV, averaged over each gene across HER2+ and CSCC ST datasets.

#' Plot Boxplots of Correlations
# Per-gene test-set metrics (PCC, MI, SSIM, AUC), pooled over the HER2+ and
# cSCC datasets.
pred_gene_msr_df <- pred_feat_cor %>%
  mutate(dataset = "her2+") %>%
  bind_rows(
    pred_feat_cor_cscc %>%
      mutate(dataset = "cscc")
  ) %>%
  filter(pred_type == "test") %>%
  select(gene, pred_type, train_fold, model_id, dataset,
         img_id, cor_pearson, mi, ssim, auc_0) %>%
  # Average over images within each gene / fold / method / dataset
  group_by(gene, pred_type, train_fold, model_id, dataset) %>%
  summarise(across(where(is.numeric), function(col) mean(col, na.rm = TRUE)),
            .groups = "drop_last") %>%
  # Average over folds within each gene / method / dataset
  group_by(pred_type, gene, model_id, dataset) %>%
  summarise(across(where(is.numeric), function(col) mean(col, na.rm = TRUE)),
            .groups = "drop_last")

# Shared theme tweaks; `linewidth` replaces the `size` argument of
# element_rect(), deprecated as of ggplot2 3.4.0.
th <- theme(
  text = element_text(size = 16),
  axis.text.x = element_text(angle = 0, hjust = 0.5),
  panel.grid.major = element_blank(),
  panel.grid.minor = element_blank(),
  panel.background = element_rect(colour = "black", linewidth = 0.7, fill = NA)
)

# Order the methods: rows are sorted by the listed order and the factor
# levels are then taken from the order of first appearance.
pred_gene_msr_df <- pred_gene_msr_df %>%
  arrange(factor(model_id, levels = c("iStar","THItoGene", "TCGN", "EGNv2", "EGNv1", "Hist2ST", "DeepPT", "DeepSpaCE", "GeneCodeR", "HisToGene", "ST-Net")))
pred_gene_msr_df$model_id <- factor(pred_gene_msr_df$model_id, levels = unique(pred_gene_msr_df$model_id))

fig_2d <- pred_gene_msr_df %>%
  pivot_longer(cols = c("cor_pearson", "mi", "ssim", "auc_0"),
               names_to = "metric",
               values_to = "value") %>%
  # Relabel the metrics and fix the facet order in one step.
  mutate(metric = factor(
    metric,
    levels = c("cor_pearson", "mi", "ssim", "auc_0"),
    labels = c("PCC", "MI", "SSIM", "AUC")
  )) %>%
  ggplot() +
  aes(x = model_id, y = value, fill = model_id, col = model_id) +
  geom_violin(alpha = 0.2) +
  geom_boxplot(alpha = 0.5, width = 0.3) +
  facet_wrap(~metric, nrow = 2, scales = "free") +
  coord_flip() +
  scale_color_manual(values = dl_method_pal) +
  scale_fill_manual(values = dl_method_pal) +
  theme(legend.position = "none") +
  th +
  labs(x = "", y = "")

fig_2d

Supplementary Figure 1

Violin and boxplots of evaluation metrics for gene expression for each method in the HER2+ ST dataset

# Plot Boxplots of Correlations
# Per-gene metrics for three gene sets: all genes, the HVGs and the SVGs.
# Rows are duplicated per gene set and tagged in `gene_set`.
pred_gene_msr_df <- pred_feat_cor %>%
  mutate(gene_set = "All Genes") %>%
  bind_rows(
    pred_feat_cor %>%
      ungroup() %>%
      filter(gene %in% hv_genes) %>%
      mutate(gene_set = "HVGs")
  ) %>%
  bind_rows(
    pred_feat_cor %>%
      ungroup() %>%
      semi_join(svg_genes, by = c("gene", "img_id")) %>%
      mutate(gene_set = "SVGs")
  ) %>%
  select(gene, pred_type, train_fold, model_id,
         img_id, cor_pearson, nrmse_sd, js_div, mi, ssim, auc_0,
         gene_set) %>%
  # Average over images within each gene / split / fold / gene set / method
  group_by(gene, pred_type, train_fold, gene_set, model_id) %>%
  summarise(across(where(is.numeric), function(col) mean(col, na.rm = TRUE)),
            .groups = "drop_last") %>%
  # Average over folds within each gene / split / gene set / method
  group_by(pred_type, gene, gene_set, model_id) %>%
  summarise(across(where(is.numeric), function(col) mean(col, na.rm = TRUE)),
            .groups = "drop_last")

# Mean test-set metrics per method.
# NOTE(review): rows of all three gene sets are pooled here, so genes that
# are also HVGs/SVGs contribute more than once -- confirm this is intended.
pred_gene_msr_df %>%
  filter(pred_type == "test") %>%
  group_by(model_id) %>%
  summarise(across(c(cor_pearson, mi, ssim, auc_0),
                   function(col) mean(col, na.rm = TRUE)))
## # A tibble: 11 × 5
##    model_id  cor_pearson     mi   ssim auc_0
##    <chr>           <dbl>  <dbl>  <dbl> <dbl>
##  1 DeepPT         0.119  0.0462 0.0776 0.570
##  2 DeepSpaCE      0.0630 0.0375 0.0389 0.581
##  3 EGNv1          0.0475 0.0363 0.0376 0.526
##  4 EGNv2          0.176  0.0541 0.137  0.613
##  5 GeneCodeR      0.0896 0.0386 0.0406 0.560
##  6 HisToGene      0.0926 0.0552 0.0603 0.592
##  7 Hist2ST        0.102  0.0622 0.0642 0.607
##  8 ST-Net         0.140  0.0534 0.0875 0.592
##  9 TCGN           0.0965 0.0529 0.0644 0.587
## 10 THItoGene      0.0739 0.0567 0.0470 0.578
## 11 iStar          0.0544 0.0336 0.0667 0.552
# Order the methods: rows are sorted by the listed order and the factor
# levels are then taken from the order of first appearance. Methods missing
# from the list (e.g. iStar) become NA in the sort key and therefore sort last.
pred_gene_msr_df <- pred_gene_msr_df %>%
  arrange(factor(model_id, levels = c("ST-Net", "HisToGene", "GeneCodeR", "DeepSpaCE", "DeepPT", "Hist2ST", "EGNv1", "EGNv2", "TCGN", "THItoGene")))
pred_gene_msr_df$model_id <- factor(pred_gene_msr_df$model_id, levels = unique(pred_gene_msr_df$model_id))

# Boxplots of metrics: one panel per gene set x metric, test split only
p_all_boxplots_her2 <- pred_gene_msr_df %>%
  filter(pred_type == "test") %>%
  pivot_longer(cols = c("cor_pearson", "nrmse_sd", "js_div", "mi", "ssim", "auc_0"),
               names_to = "metric",
               values_to = "value") %>%
  # Fix the panel order of the metrics
  mutate(metric = factor(metric, levels = c("cor_pearson", "mi",
                                            "js_div", "nrmse_sd", "ssim", "auc_0"))) %>%
  ggplot() +
  aes(x = model_id, y = value, fill = model_id, col = model_id) +
  geom_violin(alpha = 0.2) +
  geom_boxplot(alpha = 0.5, width = 0.3) +
  facet_wrap(~gene_set + metric,
             nrow = 3, scales = "free") +
  scale_color_manual(values = dl_method_pal) +
  scale_fill_manual(values = dl_method_pal) +
  theme(axis.text.x = element_text(angle = 45, hjust = 1, vjust = 1, size = 7),
        legend.position = "bottom",
        legend.title = element_text(size = 16),
        legend.text = element_text(size = 10),
        legend.key.size = unit(0.4, "cm"),
        panel.grid.major = element_blank(),
        panel.grid.minor = element_blank(),
        # `linewidth` replaces the `size` argument of element_rect(),
        # deprecated as of ggplot2 3.4.0.
        panel.background = element_rect(colour = "black", linewidth = 0.7, fill = NA)) +
  labs(title = "Boxplot of Metrics (HER2+ ST)", x = "", y = "",
       col = "", fill = "")

p_all_boxplots_her2
## Warning: Removed 63 rows containing non-finite outside the scale range
## (`stat_ydensity()`).
## Warning: Removed 63 rows containing non-finite outside the scale range
## (`stat_boxplot()`).

Supplementary Figure 2

Violin and boxplots of evaluation metrics for gene expression for each method in the CSCC ST dataset

# Plot Boxplots of Correlations
# Per-gene metrics for three gene sets: all genes, the HVGs and the SVGs.
# Rows are duplicated per gene set and tagged in `gene_set`.
pred_gene_msr_df_cscc <- pred_feat_cor_cscc %>%
  mutate(gene_set = "All Genes") %>%
  bind_rows(
    pred_feat_cor_cscc %>%
      ungroup() %>%
      filter(gene %in% hv_genes_cscc) %>%
      mutate(gene_set = "HVGs")
  ) %>%
  bind_rows(
    pred_feat_cor_cscc %>%
      ungroup() %>%
      semi_join(svg_genes_cscc, by = c("gene", "img_id")) %>%
      mutate(gene_set = "SVGs")
  ) %>%
  select(gene, pred_type, train_fold, model_id,
         img_id, cor_pearson, nrmse_sd, js_div, mi, ssim, auc_0,
         gene_set) %>%
  # Average over images within each gene / split / fold / gene set / method
  group_by(gene, pred_type, train_fold, gene_set, model_id) %>%
  summarise(across(where(is.numeric), function(col) mean(col, na.rm = TRUE)),
            .groups = "drop_last") %>%
  # Average over folds within each gene / split / gene set / method
  group_by(pred_type, gene, gene_set, model_id) %>%
  summarise(across(where(is.numeric), function(col) mean(col, na.rm = TRUE)),
            .groups = "drop_last")

# Order the methods: rows are sorted by the listed order and the factor
# levels are then taken from the order of first appearance (methods missing
# from the list sort last).
pred_gene_msr_df_cscc <- pred_gene_msr_df_cscc %>%
  arrange(factor(model_id, levels = c("ST-Net", "HisToGene", "GeneCodeR", "DeepSpaCE", "DeepPT", "Hist2ST", "EGNv1", "EGNv2", "TCGN", "THItoGene")))
pred_gene_msr_df_cscc$model_id <- factor(pred_gene_msr_df_cscc$model_id, levels = unique(pred_gene_msr_df_cscc$model_id))

# Boxplots of metrics: one panel per gene set x metric, test split only
p_all_boxplots_cscc <- pred_gene_msr_df_cscc %>%
  filter(pred_type == "test") %>%
  pivot_longer(cols = c("cor_pearson", "nrmse_sd", "js_div", "mi", "ssim", "auc_0"),
               names_to = "metric",
               values_to = "value") %>%
  # Fix the panel order of the metrics
  mutate(metric = factor(metric, levels = c("cor_pearson", "mi",
                                            "js_div", "nrmse_sd", "ssim", "auc_0"))) %>%
  ggplot() +
  aes(x = model_id, y = value, fill = model_id, col = model_id) +
  geom_violin(alpha = 0.2) +
  geom_boxplot(alpha = 0.5, width = 0.3) +
  facet_wrap(~gene_set + metric,
             nrow = 3, scales = "free") +
  scale_color_manual(values = dl_method_pal) +
  scale_fill_manual(values = dl_method_pal) +
  theme(axis.text.x = element_text(angle = 45, hjust = 1, vjust = 1, size = 7),
        legend.position = "bottom",
        legend.title = element_text(size = 16),
        legend.text = element_text(size = 10),
        legend.key.size = unit(0.4, "cm"),
        panel.grid.major = element_blank(),
        panel.grid.minor = element_blank(),
        # `linewidth` replaces the `size` argument of element_rect(),
        # deprecated as of ggplot2 3.4.0.
        panel.background = element_rect(colour = "black", linewidth = 0.7, fill = NA)) +
  labs(title = "Boxplot of Metrics (cSCC ST)", x = "", y = "",
       col = "", fill = "")

p_all_boxplots_cscc
## Warning: Removed 6 rows containing non-finite outside the scale range
## (`stat_ydensity()`).
## Warning: Removed 6 rows containing non-finite outside the scale range
## (`stat_boxplot()`).

Supplementary Figure 6 - Visium-Kidney

Violin and boxplots of evaluation metrics for gene expression for each method in the Visium-Kidney dataset

# Plot Boxplots of Correlations
# Per-gene metrics for four gene sets: all genes, the HVGs, the SVGs and the
# HSGs. Rows are duplicated per gene set and tagged in `gene_set`.
pred_gene_msr_df_kidney <- pred_feat_cor_kidney_992 %>%
  mutate(gene_set = "All Genes") %>%
  bind_rows(
    pred_feat_cor_kidney_992 %>%
      ungroup() %>%
      filter(gene %in% hv_genes_kidney) %>%
      mutate(gene_set = "HVGs")
  ) %>%
  bind_rows(
    # Unlike the other datasets, the SVG table is matched on gene only.
    pred_feat_cor_kidney_992 %>%
      ungroup() %>%
      semi_join(svg_kidney, by = c("gene")) %>%
      mutate(gene_set = "SVGs")
  ) %>%
  bind_rows(
    # HSGs: the gene / image pairs that also occur in the 145-gene results;
    # the metrics themselves still come from the 992-gene results.
    pred_feat_cor_kidney_992 %>%
      ungroup() %>%
      semi_join(pred_feat_cor_kidney_145, by = c("gene", "img_id")) %>%
      mutate(gene_set = "HSGs")
  ) %>%
  select(gene, model_id, train_fold,
         img_id, cor_pearson, nrmse_sd, js_div, mi, ssim, auc_0,
         gene_set) %>%
  # Average over images within each gene / fold / gene set / method
  group_by(gene, train_fold, gene_set, model_id) %>%
  summarise(across(where(is.numeric), function(col) mean(col, na.rm = TRUE)),
            .groups = "drop_last") %>%
  # Average over folds within each gene / gene set / method
  group_by(gene, gene_set, model_id) %>%
  summarise(across(where(is.numeric), function(col) mean(col, na.rm = TRUE)),
            .groups = "drop_last")

# Order the methods: rows are sorted by the listed order and the factor
# levels are then taken from the order of first appearance (methods missing
# from the list sort last).
pred_gene_msr_df_kidney <- pred_gene_msr_df_kidney %>%
  arrange(factor(model_id, levels = c("ST-Net", "HisToGene", "GeneCodeR", "DeepSpaCE", "DeepPT", "Hist2ST", "EGNv1", "EGNv2", "TCGN", "THItoGene")))
pred_gene_msr_df_kidney$model_id <- factor(pred_gene_msr_df_kidney$model_id, levels = unique(pred_gene_msr_df_kidney$model_id))

# Boxplots of metrics: one panel per gene set x metric
p_all_boxplots_kidney <- pred_gene_msr_df_kidney %>%
  #filter(pred_type == "test") %>%
  pivot_longer(cols = c("cor_pearson", "nrmse_sd", "js_div", "mi", "ssim", "auc_0"),
               names_to = "metric",
               values_to = "value") %>%
  # Fix the panel order of the metrics
  mutate(metric = factor(metric, levels = c("cor_pearson", "mi",
                                            "js_div", "nrmse_sd", "ssim", "auc_0"))) %>%
  ggplot() +
  aes(x = model_id, y = value, fill = model_id, col = model_id) +
  geom_violin(alpha = 0.2) +
  geom_boxplot(alpha = 0.5, width = 0.3) +
  facet_wrap(~gene_set + metric,
             nrow = 4, scales = "free") +
  scale_color_manual(values = dl_method_pal) +
  scale_fill_manual(values = dl_method_pal) +
  theme(axis.text.x = element_text(angle = 45, hjust = 1, vjust = 1, size = 7),
        legend.position = "bottom",
        legend.title = element_text(size = 16),
        legend.text = element_text(size = 10),
        legend.key.size = unit(0.4, "cm"),
        panel.grid.major = element_blank(),
        panel.grid.minor = element_blank(),
        # `linewidth` replaces the `size` argument of element_rect(),
        # deprecated as of ggplot2 3.4.0.
        panel.background = element_rect(colour = "black", linewidth = 0.7, fill = NA)) +
  labs(title = "Boxplot of Metrics (Visium Kidney)", x = "", y = "",
       col = "", fill = "")

p_all_boxplots_kidney

Supplementary Figure 12 - Visium BC (Visium-Hercep-Test2+ for training and Visium-HER2+ for testing)

Violin and boxplots of evaluation metrics for gene expression for each method in the Visium-HER2+ dataset

#' Plot Boxplots of Correlations
# Per-gene metrics for four gene sets: all genes, the HVGs, the SVGs and the
# HSGs. Rows are duplicated per gene set and tagged in `gene_set`.
pred_gene_msr_df_whole_bc <- pred_feat_cor_whole_bc_990 %>%
  mutate(gene_set = "All Genes") %>%
  bind_rows(
    pred_feat_cor_whole_bc_990 %>%
      ungroup() %>%
      filter(gene %in% hv_genes_whole_bc) %>%
      mutate(gene_set = "HVGs")
  ) %>%
  bind_rows(
    pred_feat_cor_whole_bc_990 %>%
      ungroup() %>%
      semi_join(svg_bc, by = c("gene", "img_id")) %>%
      mutate(gene_set = "SVGs")
  ) %>%
  bind_rows(
    # HSGs: the gene / image pairs that also occur in the 274-gene results;
    # the metrics themselves still come from the 990-gene results.
    pred_feat_cor_whole_bc_990 %>%
      ungroup() %>%
      semi_join(pred_feat_cor_whole_bc_274, by = c("gene", "img_id")) %>%
      mutate(gene_set = "HSGs")
  ) %>%
  select(gene, train_fold, model_id,
         img_id, cor_pearson, nrmse_sd, js_div, mi, ssim, auc_0,
         gene_set) %>%
  # Average over images within each gene / fold / gene set / method
  group_by(gene, train_fold, gene_set, model_id) %>%
  summarise(across(where(is.numeric), function(col) mean(col, na.rm = TRUE)),
            .groups = "drop_last") %>%
  # Average over folds within each gene / gene set / method
  group_by(gene, gene_set, model_id) %>%
  summarise(across(where(is.numeric), function(col) mean(col, na.rm = TRUE)),
            .groups = "drop_last")

# Mean metrics per method. Missing values are skipped: a plain mean() turns
# the whole auc_0 column into NaN as soon as one gene has no defined AUC.
# NOTE(review): rows of all gene sets are pooled here, so genes belonging to
# several sets contribute more than once -- confirm this is intended.
pred_gene_msr_df_whole_bc %>%
  group_by(model_id) %>%
  summarise(across(c(cor_pearson, mi, ssim, auc_0),
                   function(col) mean(col, na.rm = TRUE)))
## # A tibble: 6 × 5
##   model_id  cor_pearson      mi     ssim auc_0
##   <chr>           <dbl>   <dbl>    <dbl> <dbl>
## 1 DeepPT        0.0129  0.00815  0.0159    NaN
## 2 EGNv1         0.0734  0.00942  0.0695    NaN
## 3 GeneCodeR    -0.0169  0.00708  0.00393   NaN
## 4 ST-Net        0.00669 0.00999  0.0134    NaN
## 5 TCGN          0.0602  0.0186   0.0540    NaN
## 6 THItoGene    -0.0252  0.0105  -0.00925   NaN
# Order the methods: rows are sorted by the listed order and the factor
# levels are then taken from the order of first appearance (methods missing
# from the list sort last).
pred_gene_msr_df_whole_bc <- pred_gene_msr_df_whole_bc %>%
  arrange(factor(model_id, levels = c("ST-Net", "HisToGene", "GeneCodeR", "DeepSpaCE", "DeepPT", "Hist2ST", "EGNv1", "EGNv2", "TCGN", "THItoGene")))
pred_gene_msr_df_whole_bc$model_id <- factor(pred_gene_msr_df_whole_bc$model_id, levels = unique(pred_gene_msr_df_whole_bc$model_id))

# Boxplots of metrics: one panel per gene set x metric
p_all_boxplots_whole_bc <- pred_gene_msr_df_whole_bc %>%
  #filter(pred_type == "test") %>%
  pivot_longer(cols = c("cor_pearson", "nrmse_sd", "js_div", "mi", "ssim", "auc_0"),
               names_to = "metric",
               values_to = "value") %>%
  # Fix the panel order of the metrics
  mutate(metric = factor(metric, levels = c("cor_pearson", "mi",
                                            "js_div", "nrmse_sd", "ssim", "auc_0"))) %>%
  ggplot() +
  aes(x = model_id, y = value, fill = model_id, col = model_id) +
  geom_violin(alpha = 0.2) +
  geom_boxplot(alpha = 0.5, width = 0.3) +
  facet_wrap(~gene_set + metric,
             nrow = 4, scales = "free") +
  scale_color_manual(values = dl_method_pal) +
  scale_fill_manual(values = dl_method_pal) +
  theme(axis.text.x = element_text(angle = 45, hjust = 1, vjust = 1, size = 7),
        legend.position = "bottom",
        legend.title = element_text(size = 16),
        legend.text = element_text(size = 10),
        legend.key.size = unit(0.4, "cm"),
        panel.grid.major = element_blank(),
        panel.grid.minor = element_blank(),
        # `linewidth` replaces the `size` argument of element_rect(),
        # deprecated as of ggplot2 3.4.0.
        panel.background = element_rect(colour = "black", linewidth = 0.7, fill = NA)) +
  labs(title = "Boxplot of Metrics (Visium-HER2+)", x = "", y = "",
       col = "", fill = "")

p_all_boxplots_whole_bc
## Warning: Removed 34 rows containing non-finite outside the scale range
## (`stat_ydensity()`).
## Warning: Removed 34 rows containing non-finite outside the scale range
## (`stat_boxplot()`).

Figure 3a

PCC and SSIM violin and boxplots for each method in the HER2+ dataset, for all genes as well as for the HVG and SVG subsets. The significance of the difference between each subset and all genes is calculated using a one-sided Wilcoxon rank-sum test.

# Long-format PCC / SSIM values of the test split, one row per
# gene x gene set x method x metric.
pred_cor_hvg_df <- pred_gene_msr_df %>%
  filter(pred_type == "test") %>%
  pivot_longer(cols = c("cor_pearson", "ssim"),
               names_to = "metric",
               values_to = "value") %>%
  # Relabel on the character column directly: running ifelse() on a factor
  # would return its integer codes for any value that is not matched.
  mutate(metric = case_when(
    metric == "ssim" ~ "SSIM",
    metric == "cor_pearson" ~ "PCC",
    TRUE ~ metric
  ))

#' Turn a test result into a significance marker and a plotmath p-value label.
#'
#' @param test_result A list-like object with a `p.value` element, such as the
#'   value returned by `wilcox.test()`.
#' @param group Text used as the subscript of `p` in the label.
#' @return A list with `signif` ("***" for p < 0.001, "**" for p < 0.01, "*"
#'   for p < 0.05, a degree sign for p < 0.1, otherwise `NA`) and
#'   `stat_report`, a string of the form `"italic(p[group]) == <p>"` with the
#'   p-value formatted to two significant digits.
sigFunc <- function(test_result, group) {
  p_value <- test_result$p.value

  if (is.na(p_value)) {
    # An undefined p-value (NA/NaN) is treated as not significant instead of
    # making the comparison in `if` fail.
    signif <- NA
  } else if (p_value < 0.001) {
    signif <- "***"
  } else if (p_value < 0.01) {
    signif <- "**"
  } else if (p_value < 0.05) {
    signif <- "*"
  } else if (p_value < 0.1) {
    signif <- "\u00b0"
  } else {
    signif <- NA
  }

  stat_report <- paste0(
    "italic(p[", group, "]) == ", sprintf("%.2g", p_value)
  )

  list(signif = signif, stat_report = stat_report)
}

#' Test one gene subset against all genes for every method and metric.
#'
#' Runs a one-sided Wilcoxon rank-sum test per pred_type / method / metric.
#' `gene_set` is a character column, so its levels are ordered alphabetically
#' by the formula interface and `alternative = "less"` tests whether
#' "All Genes" values are lower than those of the subset.
#'
#' @param df Long-format metric table with columns `gene_set`, `pred_type`,
#'   `model_id`, `metric` and `value`.
#' @param subset_label Value of `gene_set` to compare with "All Genes".
#' @param group Subscript passed to `sigFunc()` for the p-value label.
#' @return One row per group with the significance marker, the p-value label,
#'   the bracket end points (`xmin`, `xmax`) and the largest value in the group.
wilcoxVsAllGenes <- function(df, subset_label, group) {
  df %>%
    filter(gene_set %in% c("All Genes", subset_label)) %>%
    group_by(pred_type, model_id, metric) %>%
    do({
      test_result <- wilcox.test(value ~ gene_set, data = ., alternative = "less")
      sig <- sigFunc(test_result, group = group)
      data.frame(
        signif = sig$signif,
        stat_report = sig$stat_report,
        xmin = "All Genes",
        xmax = subset_label,
        # Ignore missing values so a single NA/NaN does not remove the
        # bracket position for the whole panel.
        max_val = max(.$value, na.rm = TRUE)
      )
    })
}

ptest_HVG <- wilcoxVsAllGenes(pred_cor_hvg_df, subset_label = "HVGs", group = "hvg")

ptest_SVG <- wilcoxVsAllGenes(pred_cor_hvg_df, subset_label = "SVGs", group = "svg")

p_hvg_cor_ssim <- pred_cor_hvg_df %>%
  ggplot() +
  aes(x = gene_set,
      y = value,
      fill = gene_set,
      col = gene_set) +
  geom_violin(alpha = 0.2,
              width = 1,
              position = position_dodge(width = 0.6)) +
  geom_boxplot(alpha = 0.5,
               width = 0.3,
               position = position_dodge(width = 0.6)) +
  scale_color_manual(values = c("All Genes" = "#4393C3", "HVGs" = "#009E73", "SVGs" = "#f6b000")) +
  scale_fill_manual(values = c("All Genes" = "#4393C3", "HVGs" = "#009E73", "SVGs" = "#f6b000")) +
  
  geom_signif(
    aes(
      xmin = xmin,
      xmax = xmax,
      y_position = max_val + max_val / 10,
      annotations = signif
    ),
    color = "grey70",
    textsize = 5.5,
    vjust = 1.5,
    margin_top = 0.15,
    data = ptest_HVG,
    manual = TRUE,
    inherit.aes = FALSE
  ) +
  geom_signif(
    aes(
      xmin = xmin,
      xmax = xmax,
      y_position = max_val + max_val / 4.5,
      annotations = signif
    ),
    color = "grey40",
    textsize = 5.5,
    vjust = 1.5,
    margin_top = 0.15,
    data = ptest_SVG,
    manual = TRUE,
    inherit.aes = FALSE
  ) +
  geom_text(
    aes(
      x = as.numeric(factor(xmin)) + 0.9,
      y = -0.13,
      label = stat_report
    ),
    data = ptest_HVG,
    inherit.aes = FALSE,
    size = 2.3, 
    parse = TRUE,
    color = "grey20" 
  ) +
  geom_text(
    aes(
      x = as.numeric(factor(xmax)) + 0.9, 
      y = -0.17, 
      label = stat_report
    ),
    data = ptest_SVG,
    inherit.aes = FALSE,
    size = 2.3,
    parse = TRUE,
    color = "grey20"
  ) +
  facet_grid(cols = vars(model_id), rows = vars(metric),
             switch = "y") +
  theme(
    legend.position = "bottom",
    strip.placement = "outside",
    axis.ticks.x = element_blank(),
    axis.text.x = element_blank(),
    strip.text.x = element_text(size = 14),
    strip.text.y = element_text(size = 15),
    panel.grid.major = element_blank(),
    panel.grid.minor = element_blank(),
  ) +
  labs(x = "", y = "",
       col = "Gene Set", fill = "Gene Set")
## Warning in geom_signif(aes(xmin = xmin, xmax = xmax, y_position = max_val + :
## Ignoring unknown aesthetics: xmin, xmax, y_position, and annotations
## Warning in geom_signif(aes(xmin = xmin, xmax = xmax, y_position = max_val + :
## Ignoring unknown aesthetics: xmin, xmax, y_position, and annotations
# Print the plot object so the figure is rendered
p_hvg_cor_ssim
## Warning: Removed 24 rows containing missing values or values outside the scale range
## (`geom_signif()`).
## Warning: Removed 3 rows containing missing values or values outside the scale range
## (`geom_signif()`).

Correlation - Results for top genes

Top five correlated genes in the HER2+ ST and cSCC datasets

# HER2+ ST: mean test-set PCC per gene and method, spread to one column per
# method, then ranked by the across-method average (best gene first).
top_cor_df <- pred_feat_cor %>%
  filter(pred_type == "test") %>%
  mutate(dataset = "her2+") %>%
  group_by(gene, model_id, dataset) %>%
  summarise(mean_cor = mean(cor_pearson, na.rm = TRUE), .groups = "drop") %>%
  pivot_wider(names_from = "model_id", values_from = "mean_cor") %>%
  rowwise() %>%
  # every column except the identifiers holds one method's mean PCC
  mutate(overall_mean_cor = mean(c_across(-c(gene, dataset)))) %>%
  ungroup() %>%
  arrange(dataset, desc(overall_mean_cor))

# cSCC: mean test-set PCC per gene and method, spread to one column per
# method, then ranked by the across-method average (best gene first).
top_cor_df_cscc <- pred_feat_cor_cscc %>%
  filter(pred_type == "test") %>%
  mutate(dataset = "cscc") %>%
  group_by(gene, model_id, dataset) %>%
  summarise(mean_cor = mean(cor_pearson, na.rm = TRUE), .groups = "drop") %>%
  pivot_wider(names_from = "model_id", values_from = "mean_cor") %>%
  rowwise() %>%
  # every column except the identifiers holds one method's mean PCC
  mutate(overall_mean_cor = mean(c_across(-c(gene, dataset)))) %>%
  ungroup() %>%
  arrange(dataset, desc(overall_mean_cor))

Figure 3b

#' Horizontal dodged bar chart of per-method mean PCC for the leading genes
#'
#' Keeps the first five rows of every `dataset` group, so the input is
#' expected to be sorted already (the callers pass tables ordered by
#' `overall_mean_cor`). Each gene additionally gets a text label showing its
#' rounded `overall_mean_cor`.
#'
#' @param top_cor_df Wide data frame with columns `gene`, `dataset`,
#'   `overall_mean_cor` and one column per method holding that method's mean
#'   correlation.
#' @return A ggplot object faceted by `dataset`.
plotTopGenesBar <- function(top_cor_df) {
  top_cor_df %>%
    group_by(dataset) %>%
    slice_head(n = 5) %>%
    # reversed levels so that, after coord_flip(), the first gene is on top
    mutate(gene = factor(gene, levels = rev(gene))) %>%
    pivot_longer(
      cols = !c("gene", "overall_mean_cor", "dataset"),
      names_to = "model_id",
      values_to = "cor"
    ) %>%
    ggplot(aes(x = gene, y = cor, colour = model_id, fill = model_id)) +
    geom_bar(
      stat = "identity",
      position = position_dodge(),
      alpha = 0.6,
      width = 0.75
    ) +
    # one label per gene, derived from the plot data
    geom_text(
      data = function(d) distinct(d, gene, overall_mean_cor),
      mapping = aes(x = gene, y = Inf, label = round(overall_mean_cor, 2)),
      inherit.aes = FALSE,
      size = 5,
      hjust = 1.1
    ) +
    facet_wrap(~dataset, scales = "free") +
    # NOTE(review): an infinite additive expansion is unusual; kept as-is
    scale_y_continuous(expand = c(0, Inf)) +
    scale_colour_manual(values = dl_method_pal) +
    scale_fill_manual(values = dl_method_pal) +
    coord_flip() +
    labs(y = "Mean PCC", x = "")
}

# Theme overrides shared by the following figures: larger text, no grid lines
# and a black frame around each panel.
th <- theme(
  text = element_text(size = 14),
  axis.text.x = element_text(size = 13, angle = 0, hjust = 0),
  axis.text.y = element_text(size = 13, angle = 0, hjust = 0),
  strip.text = element_text(size = 16),
  panel.grid.major = element_blank(),
  panel.grid.minor = element_blank(),
  # `linewidth` replaces the `size` argument of element_rect(), which is
  # deprecated as of ggplot2 3.4.0
  panel.background = element_rect(colour = "black", linewidth = 0.7, fill = NA)
)

# Top-gene bar chart for HER2+ ST; the dataset value is relabelled so the
# facet strip reads "HER2+ ST".
p_topgenes_her2 <- top_cor_df %>%
  mutate(dataset = ifelse(dataset == "her2+", "HER2+ ST", dataset)) %>%
  plotTopGenesBar() +
  th

p_topgenes_her2

Figure 3c

# Top-gene bar chart for cSCC; the dataset value is relabelled so the facet
# strip reads "cSCC ST". The legend sits below the panel without a title.
p_topgenes_cscc <- top_cor_df_cscc %>%
  mutate(dataset = ifelse(dataset == "cscc", "cSCC ST", dataset)) %>%
  plotTopGenesBar() +
  theme(
    plot.title = element_text(size = 14, face = "bold"),
    legend.position = "bottom",
    legend.text = element_text(size = 14),
    legend.key.size = unit(14, "points"),
    legend.title = element_text(size = 13),
    strip.text.x = element_text(size = 14)
  ) +
  labs(colour = "", fill = "") +
  th

p_topgenes_cscc

AUC vs. threshold - Figure 3d

Plot of the AUC (averaged over each gene) of the predicted gene expression distinguishing a binarisation of the ground truth gene expression. Ground truth is binarised according to whether the value was greater than or equal to several thresholds (x-axis).

# Mean AUC per method and count threshold, for all genes and for HVGs only.
# The table is reduced in three averaging passes (see the step comments).
pred_n_auc <- bind_rows(
  pred_feat_cor %>%
    mutate(gene_set = "All Genes"),
  pred_feat_cor %>%
    ungroup() %>%
    filter(gene %in% hv_genes) %>%
    mutate(gene_set = "HVGs")
) %>%
  # keep only the AUC metrics among the measurement columns
  select(
    -starts_with(c("cor", "var", "mean", "nrmse")),
    -c("js_div", "mi", "rmse")
  ) %>%
  pivot_longer(
    cols = starts_with("auc"),
    names_to = "metric",
    values_to = "auc"
  ) %>%
  # the threshold is the part of the column name after "auc_"
  mutate(gene_count_threshold = gsub("auc_", "", metric)) %>%
  # patient id = first character of the image id
  mutate(pat_id = substr(img_id, 1, 1)) %>%
  # 1) average the images of each patient
  group_by(
    pat_id, gene, pred_type, train_fold, gene_set, model_id,
    gene_count_threshold
  ) %>%
  summarise_if(is.numeric, mean, na.rm = TRUE) %>%
  # 2) average over patients and genes
  group_by(pred_type, train_fold, gene_set, model_id, gene_count_threshold) %>%
  summarise_if(is.numeric, mean, na.rm = TRUE) %>%
  # 3) average over folds
  group_by(pred_type, gene_set, model_id, gene_count_threshold) %>%
  summarise_if(is.numeric, mean, na.rm = TRUE)

# Mean test-set AUC against the binarisation threshold, one path per method
# (all genes only).
p_auc_thresh <- pred_n_auc %>%
  filter(gene_set == "All Genes", pred_type == "test") %>%
  mutate(gene_count_threshold = as.numeric(gene_count_threshold)) %>%
  # geom_path() connects points in row order, so sort by threshold first
  arrange(pred_type, gene_set, model_id, gene_count_threshold) %>%
  ggplot(aes(x = gene_count_threshold, y = auc, colour = model_id)) +
  geom_point() +
  geom_path() +
  scale_colour_manual(values = dl_method_pal) +
  labs(
    x = "Gene Count Threshold",
    y = "Mean AUC",
    colour = "Method"
  ) +
  th

p_auc_thresh

Heatmap of gene correlations

# One row per method and one column per gene, holding the mean test-set PCC.
hm_dat <- pred_feat_cor %>%
  filter(pred_type == "test") %>%
  group_by(gene, pred_type, model_id) %>%
  summarise(mean_cor_pearson = mean(cor_pearson, na.rm = TRUE)) %>%
  pivot_wider(
    names_from = "gene",
    values_from = "mean_cor_pearson"
  ) %>%
  ungroup()
## `summarise()` has grouped output by 'gene', 'pred_type'. You can override using
## the `.groups` argument.
# Numeric matrix for the heatmap: methods in rows, genes in columns
hm_mat <- as.matrix(select(hm_dat, -c(pred_type, model_id)))
rownames(hm_mat) <- hm_dat$model_id

# Blue -> cream -> red colour scale with the middle colour pinned at 0.1
col_fun <- colorRamp2(
  breaks = c(min(hm_mat, na.rm = TRUE), 0.1, max(hm_mat, na.rm = TRUE)),
  colors = c("#66a6cc", "#fff8de", "#cc2010")
)

# Per-gene summary statistics of the expression table (every column except
# the `X` identifier column), sorted by gene name.
ha_dat <- deeppt_exprs_df %>%
  select(-X) %>%
  {
    data.frame(
      mean_exprs = colMeans(.),
      gene = colnames(.),
      var_exprs = apply(., 2, var),
      sd_exprs = apply(., 2, sd)
    )
  } %>%
  mutate(
    hvg = gene %in% hv_genes,
    # coefficient of variation, in percent
    cv_exprs = sd_exprs / mean_exprs * 100
  ) %>%
  arrange(gene)

# Hierarchical clustering of methods (rows) and genes (columns), with the
# dendrogram leaves re-ordered by dendsort
row_dend <- dendsort(hclust(dist(hm_mat)), type = "average")
col_dend <- dendsort(hclust(dist(t(hm_mat))), type = "average")

Check Variance of original & Normalisation - Supplementary Figure 5b

Scatterplot of gene expression variance before (x-axis) and after normalisation (y-axis) for each method.

# Per-gene, per-method averages of the expression moments (raw and
# normalised), PCC and SSIM.
pred_gene_msr_df <- pred_feat_cor %>%
  mutate(gene_set = "All Genes") %>%
  select(
    gene, pred_type, train_fold, model_id, img_id,
    var_exprs, var_exprs_orig, mean_exprs, mean_exprs_orig,
    cor_pearson, ssim,
    gene_set
  ) %>%
  # average the images within each fold
  group_by(gene, pred_type, train_fold, gene_set, model_id) %>%
  summarise_if(is.numeric, mean, na.rm = TRUE) %>%
  # average over folds for each gene and method
  group_by(pred_type, gene, gene_set, model_id) %>%
  summarise_if(is.numeric, mean, na.rm = TRUE)

# Sort rows into the display order of the methods, then freeze that order as
# the factor levels of `model_id`
pred_gene_msr_df <- pred_gene_msr_df %>%
  arrange(factor(
    model_id,
    levels = c(
      "ST-Net", "HisToGene", "GeneCodeR", "DeepSpaCE", "DeepPT", "Hist2ST",
      "EGNv1", "EGNv2", "TCGN", "THItoGene", "iStar"
    )
  ))
pred_gene_msr_df$model_id <- factor(
  pred_gene_msr_df$model_id,
  levels = unique(pred_gene_msr_df$model_id)
)

# Log variance after normalisation (y) against log variance of the original
# expression (x); one point per gene, one facet per method.
pred_gene_msr_df %>%
  filter(pred_type == "test") %>%
  ggplot(aes(x = log(var_exprs_orig), y = log(var_exprs))) +
  geom_point(size = 1) +
  facet_wrap(~model_id, nrow = 3, scales = "free") +
  theme(
    strip.text.x = element_text(size = 10),
    strip.text.y = element_text(size = 10),
    axis.text.x = element_text(angle = 25, hjust = 1, vjust = 1),
    legend.position = "bottom"
  ) +
  th +
  labs(x = "log(GE Variance)", y = "log(Normalised GE Variance)")

Investigate genes in heatmap - Supp Figure 5c

Scatterplot of the difference in average correlation between the DeepPT/ST-Net pair and the HisToGene/Hist2ST pair (x-axis) against the correlation between the ground truth and the normalised gene expression (y-axis). Each point represents a gene.

# Hypothesis: genes are predicted poorly when their normalised expression
# differs from the ground truth.
# The `exprs` column of the DeepPT rows is relabelled "Ground Truth" and that
# of the Hist2ST rows "NormalisedGE"; the two are then placed side by side.
filt_gene_df <- comb_pred_dat %>%
  filter(pred_type == "train", train_fold == 2) %>%
  filter(model_id %in% c("DeepPT", "Hist2ST")) %>%
  mutate(
    model_id = ifelse(model_id == "DeepPT", "Ground Truth", "NormalisedGE")
  ) %>%
  select(-pred, -row_id) %>%
  pivot_wider(names_from = "model_id", values_from = "exprs") %>%
  # patch ids look like "<img_id>_<x>x<y>"
  separate(
    "patch_id",
    into = c("img_id", "patch_coord"),
    sep = "_",
    remove = FALSE
  ) %>%
  separate("patch_coord", into = c("x", "y"), sep = "x") %>%
  mutate(across(c("x", "y"), as.numeric))

# Correlation between ground-truth and normalised expression, computed per
# gene and image and then averaged over images for each gene.
gt_norm_cor <- filt_gene_df %>%
  group_by(gene, img_id, pred_type, train_fold) %>%
  summarise(
    cor = cor(`Ground Truth`, `NormalisedGE`, use = "pairwise.complete.obs")
  ) %>%
  group_by(gene) %>%
  summarise(mean_cor = mean(cor, na.rm = TRUE))
## `summarise()` has grouped output by 'gene', 'img_id', 'pred_type'. You can
## override using the `.groups` argument.
# Changes that normalisation makes to the spatial gene expression give rise
# to differences in the correlation results.
# Per gene: mean PCC of group 1 (DeepPT, ST-Net) minus mean PCC of group 2
# (Hist2ST, HisToGene), plotted against the ground-truth vs. normalised
# expression correlation from `gt_norm_cor`.
hm_dat %>%
  pivot_longer(
    cols = !c("model_id", "pred_type"),
    names_to = "gene",
    values_to = "cor"
  ) %>%
  mutate(grp = case_when(
    model_id %in% c("DeepPT", "ST-Net") ~ 1,
    model_id %in% c("Hist2ST", "HisToGene") ~ 2,
    # typed NA keeps the column numeric on every dplyr version
    TRUE ~ NA_real_
  )) %>%
  # methods outside the two groups are not part of the comparison
  filter(!is.na(grp)) %>%
  group_by(gene, grp) %>%
  summarise(grp_mean_cor = mean(cor)) %>%
  pivot_wider(names_from = "grp", values_from = "grp_mean_cor") %>%
  mutate(cor_diff = `1` - `2`) %>%
  left_join(gt_norm_cor, by = "gene") %>%
  ggplot() +
  aes(x = cor_diff, y = mean_cor) +
  geom_point(alpha = 0.6) +
  geom_smooth(method = "lm") +
  th +
  theme(
    axis.title.x = element_text(size = 12),
    axis.title.y = element_text(size = 12)
  ) +
  # method name corrected from "HistToGene" to "HisToGene"
  labs(
    y = "Correlation between Ground Truth GE and Normalised GE",
    x = "Average Cor. Diff between DeepPT/ST-Net & HisToGene/Hist2ST"
  )
## `summarise()` has grouped output by 'gene'. You can override using the
## `.groups` argument.
## `geom_smooth()` using formula = 'y ~ x'

Figure 3e

Heatmap of average correlation of each gene in HER2+ ST dataset and each method. The log of the mean and variance of each gene are coloured above the heatmap.

# Check which cluster contains which genes
ha_dat <- ha_dat %>%
  arrange(desc(mean_exprs))

# HeatmapAnnotation() pairs its vectors with the heatmap columns by position.
# `ha_dat` has just been re-sorted by mean expression, so its rows are looked
# up in the column order of `hm_mat` before being used; otherwise each gene
# would be annotated with another gene's mean and variance.
ha_col_idx <- match(colnames(hm_mat), ha_dat$gene)
if (anyNA(ha_col_idx)) {
  warning(
    sum(is.na(ha_col_idx)),
    " heatmap gene(s) have no entry in `ha_dat`; their annotations will be NA",
    call. = FALSE
  )
}
ha_log_mean <- log10(ha_dat$mean_exprs[ha_col_idx])
ha_log_var <- log10(ha_dat$var_exprs[ha_col_idx])

# Top annotation: log10 of each gene's mean and variance (the annotation
# names say "log"), each on its own white-to-colour scale.
ha <- HeatmapAnnotation(
  `log(mean)` = ha_log_mean,
  `log(var)` = ha_log_var,
  annotation_name_side = "left",
  col = list(
    `log(mean)` = colorRamp2(
      range(ha_log_mean, na.rm = TRUE),
      c("white", "#004444")
    ),
    `log(var)` = colorRamp2(
      range(ha_log_var, na.rm = TRUE),
      c("white", "#ce3b39")
    )
  ),
  # NOTE(review): seven flags for two annotations; looks like a leftover from
  # a version with more annotation tracks - confirm and shorten
  show_legend = c(TRUE, TRUE, FALSE, FALSE, FALSE, FALSE, FALSE),
  annotation_legend_param = list(
    `log(mean)` = list(direction = "horizontal"),
    `log(var)` = list(direction = "horizontal")
  )
)

#pdf("/dskh/nobackup/chuhanw/Benchmarking-HE-ST/accept_figures/fig_3e.pdf", width = 16, height = 9)

# Heatmap of mean PCC (methods x genes) with the gene-level annotation on
# top; gene labels are blanked out.
p_heatmap <- Heatmap(
  matrix = hm_mat,
  name = "Average Correlation",
  col = col_fun,
  top_annotation = ha,
  column_labels = rep("", ncol(hm_mat)),
  row_names_gp = gpar(fontsize = 16),
  heatmap_legend_param = list(direction = "horizontal")
  # column_title = "Gene",
  # cluster_columns = FALSE
)

# Draw with all legends merged into a single block below the heatmap
p_heatmap_botleg <- draw(
  p_heatmap,
  merge_legend = TRUE,
  heatmap_legend_side = "bottom",
  annotation_legend_side = "bottom"
)

#dev.off()

Figure 3f - Visium-Kidney

PCC violin and box plots for each method in the Visium-Kidney dataset, for all genes as well as for HVGs, SVGs and HSGs only. Significance of the difference between each of HVGs, SVGs and HSGs and all genes is calculated using the Wilcoxon rank-sum test.

# Long format with one `value` per gene/method, and a display label for the
# metric. The previous factor conversion followed by a nested ifelse() would
# have returned integer factor codes for any unmapped metric; case_when() on
# the character column keeps unmapped names unchanged.
pred_cor_hvg_df_kidney <- pred_gene_msr_df_kidney %>%
  pivot_longer(
    cols = c("cor_pearson"),
    names_to = "metric",
    values_to = "value"
  ) %>%
  mutate(metric = case_when(
    metric == "ssim" ~ "SSIM",
    metric == "cor_pearson" ~ "PCC",
    TRUE ~ as.character(metric)
  ))

#' Turn a test result into a significance symbol and a plotmath p-value label
#'
#' Symbols: "***" for p < 0.001, "**" for p < 0.01, "*" for p < 0.05,
#' "\u00b0" for p < 0.1, otherwise `NA`. A missing (`NA`/`NaN`) p-value also
#' yields `NA` instead of raising an error.
#'
#' @param test_result List-like object with a single numeric `p.value`
#'   element (e.g. an `htest`).
#' @param group Text used as the subscript of `p` in the label.
#' @return A list with `signif` (character, possibly `NA`) and `stat_report`
#'   (a plotmath expression as a string).
sigFunc <- function(test_result, group) {
  p_value <- as.numeric(test_result$p.value)
  stopifnot(length(p_value) == 1L)

  # findInterval() uses left-closed intervals, which reproduces the strict
  # "<" comparisons; an NA p-value indexes to NA
  cutoffs <- c(0.001, 0.01, 0.05, 0.1)
  symbols <- c("***", "**", "*", "\u00b0", NA_character_)
  signif <- symbols[findInterval(p_value, cutoffs) + 1L]

  stat_report <- paste0(
    "italic(p[", group, "]) == ", sprintf("%.2g", p_value)
  )

  list(signif = signif, stat_report = stat_report)
}

# One-sided Wilcoxon rank-sum tests of each gene subset against all genes,
# run separately for every method and metric. Each table holds the
# significance symbol, the p-value label, the bracket end points and the
# largest observed value (used to position the bracket).
# NOTE(review): alternative = "less" refers to the first level of `gene_set`;
# presumably that is "All Genes" (alphabetical order) - confirm.

# HVGs vs. all genes
ptest_HVG <- pred_cor_hvg_df_kidney %>%
  filter(gene_set %in% c("All Genes", "HVGs")) %>%
  group_by(model_id, metric) %>%
  do({
    grp <- .
    wt <- wilcox.test(value ~ gene_set, data = grp, alternative = "less")
    lab <- sigFunc(wt, group = "hvg")
    data.frame(
      signif = lab$signif,
      stat_report = lab$stat_report,
      xmin = "All Genes",
      xmax = "HVGs",
      max_val = max(grp$value)
    )
  })

# SVGs vs. all genes
ptest_SVG <- pred_cor_hvg_df_kidney %>%
  filter(gene_set %in% c("All Genes", "SVGs")) %>%
  group_by(model_id, metric) %>%
  do({
    grp <- .
    wt <- wilcox.test(value ~ gene_set, data = grp, alternative = "less")
    lab <- sigFunc(wt, group = "svg")
    data.frame(
      signif = lab$signif,
      stat_report = lab$stat_report,
      xmin = "All Genes",
      xmax = "SVGs",
      max_val = max(grp$value)
    )
  })

# HSGs vs. all genes
ptest_sparse <- pred_cor_hvg_df_kidney %>%
  filter(gene_set %in% c("All Genes", "HSGs")) %>%
  group_by(model_id, metric) %>%
  do({
    grp <- .
    wt <- wilcox.test(value ~ gene_set, data = grp, alternative = "less")
    lab <- sigFunc(wt, group = "hsg")
    data.frame(
      signif = lab$signif,
      stat_report = lab$stat_report,
      xmin = "All Genes",
      xmax = "HSGs",
      max_val = max(grp$value)
    )
  })

# Fix the order of the gene sets on the x axis and in the legend
desired_order <- c("All Genes", "HVGs", "SVGs", "HSGs")
pred_cor_hvg_df_kidney <- pred_cor_hvg_df_kidney %>%
  mutate(gene_set = factor(gene_set, levels = desired_order))

# Violin + box plot of the metric per gene set, one column of panels per
# method, with significance brackets above and p-value labels below.
kidney_p_hvg_cor <- pred_cor_hvg_df_kidney %>%
  ggplot(aes(x = gene_set, y = value, fill = gene_set, colour = gene_set)) +
  geom_violin(
    alpha = 0.2,
    width = 1,
    position = position_dodge(width = 0.5)
  ) +
  geom_boxplot(
    alpha = 0.5,
    width = 0.3,
    position = position_dodge(width = 0.5)
  ) +
  scale_colour_manual(
    values = c(
      "All Genes" = "#4393C3", "HVGs" = "#009E73",
      "SVGs" = "#f6b000", "HSGs" = "#ee8227"
    )
  ) +
  scale_fill_manual(
    values = c(
      "All Genes" = "#4393C3", "HVGs" = "#009E73",
      "SVGs" = "#f6b000", "HSGs" = "#ee8227"
    )
  ) +
  # Brackets: manual mode reads the end points, height and label from the
  # ptest_* tables
  geom_signif(
    mapping = aes(
      xmin = xmin,
      xmax = xmax,
      y_position = max_val + max_val / 10,
      annotations = signif
    ),
    data = ptest_HVG,
    manual = TRUE,
    inherit.aes = FALSE,
    color = "grey70",
    textsize = 4.5,
    vjust = 1.5,
    margin_top = 0.15
  ) +
  geom_signif(
    mapping = aes(
      xmin = xmin,
      xmax = xmax,
      y_position = max_val + max_val / 4.5,
      annotations = signif
    ),
    data = ptest_SVG,
    manual = TRUE,
    inherit.aes = FALSE,
    color = "grey40",
    textsize = 4.5,
    vjust = 1.5,
    margin_top = 0.15
  ) +
  # NOTE(review): same height formula as the SVG bracket above - confirm the
  # two brackets are not meant to sit at different heights
  geom_signif(
    mapping = aes(
      xmin = xmin,
      xmax = xmax,
      y_position = max_val + max_val / 4.5,
      annotations = signif
    ),
    data = ptest_sparse,
    manual = TRUE,
    inherit.aes = FALSE,
    color = "grey20",
    textsize = 4.5,
    vjust = 1.5,
    margin_top = 0.15
  ) +
  # p-value labels, stacked at fixed y positions below the distributions
  geom_text(
    mapping = aes(
      x = as.numeric(factor(xmin)) + 1.6,
      y = -0.14,
      label = stat_report
    ),
    data = ptest_HVG,
    inherit.aes = FALSE,
    parse = TRUE,
    size = 2.5,
    color = "grey20"
  ) +
  geom_text(
    mapping = aes(
      x = as.numeric(factor(xmax)) + 1.6,
      y = -0.17,
      label = stat_report
    ),
    data = ptest_SVG,
    inherit.aes = FALSE,
    parse = TRUE,
    size = 2.5,
    color = "grey20"
  ) +
  geom_text(
    mapping = aes(
      x = as.numeric(factor(xmax)) + 1.6,
      y = -0.20,
      label = stat_report
    ),
    data = ptest_sparse,
    inherit.aes = FALSE,
    parse = TRUE,
    size = 2.5,
    color = "grey20"
  ) +
  facet_grid(
    rows = vars(metric),
    cols = vars(model_id),
    switch = "y"
  ) +
  theme(
    legend.position = "bottom",
    strip.placement = "outside",
    strip.text.x = element_text(size = 14),
    strip.text.y = element_text(size = 16),
    axis.ticks.x = element_blank(),
    axis.text.x = element_blank(),
    panel.grid.major = element_blank(),
    panel.grid.minor = element_blank()
  ) +
  labs(
    x = "", y = "",
    colour = "Gene Set", fill = "Gene Set"
  )
## Warning in geom_signif(aes(xmin = xmin, xmax = xmax, y_position = max_val + :
## Ignoring unknown aesthetics: xmin, xmax, y_position, and annotations
## Warning in geom_signif(aes(xmin = xmin, xmax = xmax, y_position = max_val + :
## Ignoring unknown aesthetics: xmin, xmax, y_position, and annotations
## Warning in geom_signif(aes(xmin = xmin, xmax = xmax, y_position = max_val + :
## Ignoring unknown aesthetics: xmin, xmax, y_position, and annotations
# Print the plot object so the figure is rendered
kidney_p_hvg_cor
## Warning: Removed 6 rows containing missing values or values outside the scale range
## (`geom_signif()`).
## Warning: Removed 9 rows containing missing values or values outside the scale range
## (`geom_signif()`).
## Warning: Removed 18 rows containing missing values or values outside the scale range
## (`geom_signif()`).

Figure 3g - Visium-HER2+

PCC violin and box plots for each method in the Visium-HER2+ dataset, for all genes as well as for HVGs, SVGs and HSGs only. Significance of the difference between each of HVGs, SVGs and HSGs and all genes is calculated using the Wilcoxon rank-sum test.

# Long format with one `value` per gene/method, and a display label for the
# metric. The previous factor conversion followed by a nested ifelse() would
# have returned integer factor codes for any unmapped metric; case_when() on
# the character column keeps unmapped names unchanged.
pred_cor_hvg_df_whole_bc <- pred_gene_msr_df_whole_bc %>%
  pivot_longer(
    cols = "cor_pearson",
    names_to = "metric",
    values_to = "value"
  ) %>%
  mutate(metric = case_when(
    metric == "ssim" ~ "SSIM",
    metric == "cor_pearson" ~ "PCC",
    TRUE ~ as.character(metric)
  ))

#' Turn a test result into a significance symbol and a plotmath p-value label
#'
#' Symbols: "***" for p < 0.001, "**" for p < 0.01, "*" for p < 0.05,
#' "\u00b0" for p < 0.1, otherwise `NA`. A missing (`NA`/`NaN`) p-value also
#' yields `NA` instead of raising an error.
#'
#' @param test_result List-like object with a single numeric `p.value`
#'   element (e.g. an `htest`).
#' @param group Text used as the subscript of `p` in the label.
#' @return A list with `signif` (character, possibly `NA`) and `stat_report`
#'   (a plotmath expression as a string).
sigFunc <- function(test_result, group) {
  p_value <- as.numeric(test_result$p.value)
  stopifnot(length(p_value) == 1L)

  # findInterval() uses left-closed intervals, which reproduces the strict
  # "<" comparisons; an NA p-value indexes to NA
  cutoffs <- c(0.001, 0.01, 0.05, 0.1)
  symbols <- c("***", "**", "*", "\u00b0", NA_character_)
  signif <- symbols[findInterval(p_value, cutoffs) + 1L]

  stat_report <- paste0(
    "italic(p[", group, "]) == ", sprintf("%.2g", p_value)
  )

  list(signif = signif, stat_report = stat_report)
}

# One-sided Wilcoxon rank-sum tests of each gene subset against all genes,
# run separately for every method and metric. Each table holds the
# significance symbol, the p-value label, the bracket end points and the
# largest observed value (used to position the bracket).
# NOTE(review): alternative = "less" refers to the first level of `gene_set`;
# presumably that is "All Genes" (alphabetical order) - confirm.

# HVGs vs. all genes
ptest_HVG <- pred_cor_hvg_df_whole_bc %>%
  filter(gene_set %in% c("All Genes", "HVGs")) %>%
  group_by(model_id, metric) %>%
  do({
    grp <- .
    wt <- wilcox.test(value ~ gene_set, data = grp, alternative = "less")
    lab <- sigFunc(wt, group = "hvg")
    data.frame(
      signif = lab$signif,
      stat_report = lab$stat_report,
      xmin = "All Genes",
      xmax = "HVGs",
      max_val = max(grp$value)
    )
  })

# SVGs vs. all genes
ptest_SVG <- pred_cor_hvg_df_whole_bc %>%
  filter(gene_set %in% c("All Genes", "SVGs")) %>%
  group_by(model_id, metric) %>%
  do({
    grp <- .
    wt <- wilcox.test(value ~ gene_set, data = grp, alternative = "less")
    lab <- sigFunc(wt, group = "svg")
    data.frame(
      signif = lab$signif,
      stat_report = lab$stat_report,
      xmin = "All Genes",
      xmax = "SVGs",
      max_val = max(grp$value)
    )
  })

# HSGs vs. all genes
ptest_sparse <- pred_cor_hvg_df_whole_bc %>%
  filter(gene_set %in% c("All Genes", "HSGs")) %>%
  group_by(model_id, metric) %>%
  do({
    grp <- .
    wt <- wilcox.test(value ~ gene_set, data = grp, alternative = "less")
    lab <- sigFunc(wt, group = "hsg")
    data.frame(
      signif = lab$signif,
      stat_report = lab$stat_report,
      xmin = "All Genes",
      xmax = "HSGs",
      max_val = max(grp$value)
    )
  })

# Fix the order of the gene sets on the x axis and in the legend
desired_order <- c("All Genes", "HVGs", "SVGs", "HSGs")
pred_cor_hvg_df_whole_bc <- pred_cor_hvg_df_whole_bc %>%
  mutate(gene_set = factor(gene_set, levels = desired_order))

# Violin + box plot of the metric per gene set, one column of panels per
# method, with significance brackets above and p-value labels below.
whole_bc_p_hvg_cor <- pred_cor_hvg_df_whole_bc %>%
  ggplot(aes(x = gene_set, y = value, fill = gene_set, colour = gene_set)) +
  geom_violin(
    alpha = 0.2,
    width = 1,
    position = position_dodge(width = 0.6)
  ) +
  geom_boxplot(
    alpha = 0.5,
    width = 0.3,
    position = position_dodge(width = 0.6)
  ) +
  scale_colour_manual(
    values = c(
      "All Genes" = "#4393C3", "HVGs" = "#009E73",
      "SVGs" = "#f6b000", "HSGs" = "#ee8227"
    )
  ) +
  scale_fill_manual(
    values = c(
      "All Genes" = "#4393C3", "HVGs" = "#009E73",
      "SVGs" = "#f6b000", "HSGs" = "#ee8227"
    )
  ) +
  # Brackets: manual mode reads the end points, height and label from the
  # ptest_* tables
  geom_signif(
    mapping = aes(
      xmin = xmin,
      xmax = xmax,
      y_position = max_val + max_val / 10,
      annotations = signif
    ),
    data = ptest_HVG,
    manual = TRUE,
    inherit.aes = FALSE,
    color = "grey70",
    textsize = 4.5,
    vjust = 1.5,
    margin_top = 0.15
  ) +
  geom_signif(
    mapping = aes(
      xmin = xmin,
      xmax = xmax,
      y_position = max_val + max_val / 4.5,
      annotations = signif
    ),
    data = ptest_SVG,
    manual = TRUE,
    inherit.aes = FALSE,
    color = "grey40",
    textsize = 4.5,
    vjust = 1.5,
    margin_top = 0.15
  ) +
  # NOTE(review): same height formula as the SVG bracket above - confirm the
  # two brackets are not meant to sit at different heights
  geom_signif(
    mapping = aes(
      xmin = xmin,
      xmax = xmax,
      y_position = max_val + max_val / 4.5,
      annotations = signif
    ),
    data = ptest_sparse,
    manual = TRUE,
    inherit.aes = FALSE,
    color = "grey20",
    textsize = 4.5,
    vjust = 1.5,
    margin_top = 0.15
  ) +
  # p-value labels, stacked at fixed y positions below the distributions
  geom_text(
    mapping = aes(
      x = as.numeric(factor(xmin)) + 1.6,
      y = -0.30,
      label = stat_report
    ),
    data = ptest_HVG,
    inherit.aes = FALSE,
    parse = TRUE,
    size = 2.5,
    color = "grey20"
  ) +
  geom_text(
    mapping = aes(
      x = as.numeric(factor(xmax)) + 1.6,
      y = -0.33,
      label = stat_report
    ),
    data = ptest_SVG,
    inherit.aes = FALSE,
    parse = TRUE,
    size = 2.5,
    color = "grey20"
  ) +
  geom_text(
    mapping = aes(
      x = as.numeric(factor(xmax)) + 1.6,
      y = -0.36,
      label = stat_report
    ),
    data = ptest_sparse,
    inherit.aes = FALSE,
    parse = TRUE,
    size = 2.5,
    color = "grey20"
  ) +
  facet_grid(
    rows = vars(metric),
    cols = vars(model_id),
    switch = "y"
  ) +
  theme(
    legend.position = "bottom",
    strip.placement = "outside",
    strip.text.x = element_text(size = 14),
    strip.text.y = element_text(size = 16),
    axis.ticks.x = element_blank(),
    axis.text.x = element_blank(),
    panel.grid.major = element_blank(),
    panel.grid.minor = element_blank()
  ) +
  labs(
    x = "", y = "",
    colour = "Gene Set", fill = "Gene Set"
  )
## Warning in geom_signif(aes(xmin = xmin, xmax = xmax, y_position = max_val + :
## Ignoring unknown aesthetics: xmin, xmax, y_position, and annotations
## Warning in geom_signif(aes(xmin = xmin, xmax = xmax, y_position = max_val + :
## Ignoring unknown aesthetics: xmin, xmax, y_position, and annotations
## Warning in geom_signif(aes(xmin = xmin, xmax = xmax, y_position = max_val + :
## Ignoring unknown aesthetics: xmin, xmax, y_position, and annotations
# Print the plot object so the figure is rendered
whole_bc_p_hvg_cor
## Warning: Removed 12 rows containing missing values or values outside the scale range
## (`geom_signif()`).
## Warning: Removed 9 rows containing missing values or values outside the scale range
## (`geom_signif()`).
## Warning: Removed 12 rows containing missing values or values outside the scale range
## (`geom_signif()`).

Supplementary Figure 3 - HER2+ Clustering

# One image sample is processed at a time.
# Load the stored clustering results.
df_clustered_comb <- readRDS("./data/processed/her2st/her2st_cluster_11.rds")

# Display order of the panels: the two references first, then the methods
desired_order <- c(
  "Ground Truth Annotation", "Ground Truth SGE",
  "ST-Net", "HisToGene", "GeneCodeR", "DeepSpaCE", "DeepPT", "Hist2ST",
  "EGNv1", "EGNv2", "TCGN", "THItoGene", "iStar"
)

# Select the image to show and split the patch id ("<img_id>_<x>x<y>") into
# image id and coordinates
cluster_sample <- df_clustered_comb %>%
  filter(img_id == "B1") %>% # choose image sample here
  separate(
    "patch_id",
    into = c("img_id", "patch_coord"),
    sep = "_",
    remove = FALSE
  ) %>%
  separate("patch_coord", into = c("x", "y"), sep = "x")

# Rename the GeneCodeR run and append two reference "methods" built from the
# DeepPT rows: one carrying `gt_cluster`, one carrying `cluster_observed`
cluster_sample <- cluster_sample %>%
  mutate(model_id = case_when(
    model_id == "genecoder_i500_j500" ~ "GeneCodeR",
    TRUE ~ model_id
  )) %>%
  bind_rows(
    cluster_sample %>%
      filter(model_id == "DeepPT") %>%
      mutate(cluster = gt_cluster) %>%
      mutate(model_id = "Ground Truth Annotation"),
    cluster_sample %>%
      filter(model_id == "DeepPT") %>%
      mutate(cluster = cluster_observed) %>%
      mutate(model_id = "Ground Truth SGE")
  ) %>%
  mutate(across(c("x", "y"), as.numeric)) %>%
  group_by(model_id)

# Adjusted Rand index of every panel's clusters against `gt_cluster`, plus a
# facet label "<method>\nARI=<value>" whose factor levels follow
# `desired_order`
ari_per_model <- cluster_sample %>%
  group_by(model_id) %>%
  summarise(ari = adjustedRandIndex(cluster, gt_cluster)) %>%
  ungroup() %>%
  arrange(factor(model_id, levels = desired_order)) %>%
  mutate(
    model_label = paste0(as.character(model_id), "\nARI=", round(ari, 3))
  ) %>%
  # rows are already in display order, so the labels can serve as levels
  mutate(model_label = factor(model_label, levels = model_label))

cluster_sample$model_id <- factor(
  cluster_sample$model_id,
  levels = desired_order
)

# Spatial map of the cluster assignments, one panel per method (labelled
# with its ARI); y is negated so the image is not drawn upside down.
fig_region_cluster <- cluster_sample %>%
  mutate(cluster = as.factor(cluster)) %>%
  left_join(
    select(ari_per_model, model_id, model_label),
    by = "model_id"
  ) %>%
  mutate(
    model_label = factor(model_label, levels = ari_per_model$model_label)
  ) %>%
  mutate(model_id = factor(
    model_id,
    levels = c(
      "DeepPT", "DeepSpaCE",
      "EGNv1", "EGNv2", "TCGN", "THItoGene", "iStar",
      "GeneCodeR", "Hist2ST", "HisToGene", "ST-Net",
      "Ground Truth Annotation", "Ground Truth SGE"
    )
  )) %>%
  ggplot(aes(x = x, y = -y, colour = cluster)) +
  geom_point(size = 0.88) +
  facet_wrap(~model_label, scales = "free", ncol = 13) +
  scale_colour_viridis_d() +
  labs(colour = "Cluster", x = "", y = "") +
  theme(
    axis.line = element_blank(),
    axis.text = element_blank(),
    axis.ticks = element_blank(),
    panel.grid.major = element_blank(),
    panel.grid.minor = element_blank(),
    legend.title = element_text(size = 12),
    legend.text = element_text(size = 10),
    legend.key.size = unit(15, "points"),
    strip.text.x = element_text(size = 12)
  )

fig_region_cluster

% Zero vs Cor - Supplementary Figure 4

Gene expression prediction evaluation metrics vs. the percentage of zeros in each gene for each method. Linear lines of best fit are plotted for each.

# Fraction of patches with exactly zero expression, per image and gene
exprs_perc_zero_df <- deeppt_exprs_df %>%
  rename(patch_id = X) %>%
  pivot_longer(
    cols = !c("patch_id"),
    names_to = "gene",
    values_to = "exprs"
  ) %>%
  # patch ids look like "<img_id>_<x>x<y>"
  separate(
    patch_id,
    into = c("img_id", "x_y"),
    sep = "_",
    remove = FALSE,
    convert = TRUE
  ) %>%
  separate(x_y, into = c("x", "y"), sep = "x") %>%
  group_by(img_id, gene) %>%
  summarise(perc_zero = mean(exprs == 0), .groups = "drop")

# Test-set metrics joined with the zero fraction, averaged per gene and
# method, then reshaped to one row per (gene, method, metric)
exprs_perc_zero_metric_df <- pred_feat_cor %>%
  filter(pred_type == "test") %>%
  left_join(exprs_perc_zero_df, by = c("img_id", "gene")) %>%
  group_by(gene, model_id, pred_type) %>%
  summarise_if(is.numeric, mean, na.rm = TRUE) %>%
  pivot_longer(
    cols = c("cor_pearson", "mi", "nrmse_sd", "ssim", "auc_0"),
    names_to = "metric",
    values_to = "value"
  ) %>%
  filter(is.finite(value))

# Sort rows into the display order of the methods, then freeze that order as
# the factor levels of `model_id`
exprs_perc_zero_metric_df <- exprs_perc_zero_metric_df %>%
  arrange(factor(
    model_id,
    levels = c(
      "ST-Net", "HisToGene", "GeneCodeR", "DeepSpaCE", "DeepPT", "Hist2ST",
      "EGNv1", "EGNv2", "TCGN", "THItoGene", "iStar"
    )
  ))
exprs_perc_zero_metric_df$model_id <- factor(
  exprs_perc_zero_metric_df$model_id,
  levels = unique(exprs_perc_zero_metric_df$model_id)
)

# Pearson correlation test between each metric and the zero fraction, per
# method and metric, plus a plotmath label "t(df) = .., p = .., r = ..".
exprs_perc_zero_metric_res <- exprs_perc_zero_metric_df %>%
  group_by(model_id, pred_type, metric) %>%
  do({
    test_result <- broom::tidy(cor.test(.$value, .$perc_zero))
    # number of rows in the group (kept for reference)
    test_result$n <- nrow(.)
    test_result
  }) %>%
  mutate(
    # `parameter` is the degrees of freedom reported by cor.test(), which
    # drops incomplete pairs; `n - 2` would overstate it whenever a value or
    # zero fraction is missing
    report = sprintf(
      "italic(t)~(%d)~'='~%.2f*','~italic(p)~'='~%.2g~','~italic(r)~'='~%.2f",
      as.integer(parameter), statistic, p.value, estimate
    )
  )

# Metric value against zero fraction, one panel per metric and method, with
# a linear fit and the test summary printed in the top-right corner
exprs_perc_zero_metric_df %>%
  ggplot(aes(x = perc_zero, y = value)) +
  geom_point(size = 1, alpha = 0.5) +
  geom_smooth(method = "lm", formula = "y~x") +
  geom_text(
    mapping = aes(label = report, x = Inf, y = Inf),
    data = exprs_perc_zero_metric_res,
    parse = TRUE,
    hjust = 1,
    vjust = 2,
    size = 2.3,
    col = "grey30"
  ) +
  theme(
    panel.grid.major = element_blank(),
    panel.grid.minor = element_blank()
  ) +
  facet_wrap(~ metric + model_id, nrow = 5, scales = "free")

Supplementary Figure 5a

# Predicted vs. observed expression of one gene (BGN) in one test image (C1)
filt_gene_df <- comb_pred_dat %>%
  filter(pred_type == "test") %>%
  filter(img_id == "C1", gene %in% "BGN") %>% # "C3"
  # patch ids look like "<img_id>_<x>x<y>"
  separate(
    "patch_id",
    into = c("img_id", "patch_coord"),
    sep = "_",
    remove = FALSE
  ) %>%
  separate("patch_coord", into = c("x", "y"), sep = "x")

# filt_gene_df_norm_BGN <- filt_gene_df_norm %>%
#  filter(gene == "BGN", img_id=="C1")

# Add two reference panels that show the `exprs` column in place of a
# prediction: "Ground Truth" (from the DeepPT rows) and
# "Hist2ST-NormalisedGE" (from the Hist2ST rows). Values are then min-max
# scaled within each panel and gene.
filt_gene_df <- filt_gene_df %>%
  bind_rows(
    filt_gene_df %>%
      filter(model_id == "DeepPT") %>%
      mutate(pred = exprs) %>%
      mutate(model_id = "Ground Truth"),
    filt_gene_df %>%
      filter(model_id == "Hist2ST") %>%
      mutate(pred = exprs) %>%
      # mutate(pred = filt_gene_df_norm_BGN$NormalisedGE) %>%
      mutate(model_id = "Hist2ST-NormalisedGE")
  ) %>%
  mutate(across(c("x", "y"), as.numeric)) %>%
  group_by(model_id, gene) %>%
  mutate(pred = minMaxScaler(pred))

# Spatial map of the scaled expression, one panel per method in a single
# row; y is negated so the image is not drawn upside down.
supp5_a_exprs <- filt_gene_df %>%
  mutate(model_id = factor(
    model_id,
    levels = c(
      "Ground Truth", "ST-Net", "HisToGene", "GeneCodeR", "DeepSpaCE",
      "DeepPT", "Hist2ST", "EGNv1", "EGNv2", "TCGN", "THItoGene", "iStar",
      "Hist2ST-NormalisedGE"
    )
  )) %>%
  ggplot(aes(x = x, y = -y, colour = pred)) +
  geom_point(size = 1.5) +
  facet_wrap(~model_id, scales = "free", ncol = 13, nrow = 1) +
  # scale_color_gradient2()+
  scale_colour_viridis_c() +
  labs(colour = "Scaled Expression", x = "", y = "") +
  theme(
    # legend.position = "bottom"
    axis.line = element_blank(),
    axis.text = element_blank(),
    axis.ticks = element_blank(),
    panel.grid.major = element_blank(),
    panel.grid.minor = element_blank(),
    legend.title = element_text(size = 9),
    legend.text = element_text(size = 7),
    legend.key.size = unit(12, "points"),
    strip.text.x = element_text(size = 12)
  )

supp5_a_exprs

Supplementary figure 7

Visium-Kidney - coefficient of variation of each ground-truth and predicted gene

# Load the per-gene variation table for Visium-Kidney and harmonise the
# "ST_Net" label with the "ST-Net" spelling used by `dl_method_pal`.
kidney_variance_data_combine <-
  readRDS("data/processed/visium/variance_data_combine_kidney.rds") %>%
  mutate(Group = case_when(
    Group == "ST_Net" ~ "ST-Net",
    TRUE ~ Group
  ))

# Fix the facet order (models) and the x-axis order (gene sets).
# factor() already returns a factor, so no further as.factor() is needed.
kidney_variance_data_combine$Group <- factor(
  kidney_variance_data_combine$Group,
  levels = c("ST-Net", "HisToGene", "GeneCodeR", "DeepSpaCE", "DeepPT",
             "Hist2ST", "EGNv1", "EGNv2", "TCGN", "THItoGene", "iStar")
)
kidney_variance_data_combine$Type <- factor(
  kidney_variance_data_combine$Type,
  levels = c("GT 992 genes", "Predicted 992 genes",
             "GT 145 genes", "Predicted 145 genes")
)

# One panel per model; boxplots of the `Variance` column for each gene set.
kidney_gene_variation <- ggplot(
  kidney_variance_data_combine,
  aes(x = Type, y = Variance, fill = Group, col = Group)
) +
  geom_boxplot(position = position_dodge(width = 0.8), width = 0.6, alpha = 0.5) +
  #geom_tile(aes(x = Type, y = -0.5, fill = Type), height = 0.3, color = "black") +
  facet_wrap(~Group, nrow = 3, ncol = 4) +
  labs(title = "", x = "", y = "") +
  theme_minimal() +
  scale_fill_manual(values = dl_method_pal) +
  scale_color_manual(values = dl_method_pal) +
  theme(
    legend.title = element_blank(),
    axis.text.x = element_text(angle = 35, hjust = 0.86, size = 5.6),
    panel.grid.major = element_blank(),
    panel.grid.minor = element_blank(),
    # `linewidth` replaces the `size` argument deprecated in ggplot2 3.4.0
    panel.background = element_rect(colour = "black", linewidth = 0.7, fill = NA)
  )

kidney_gene_variation

Supplementary figure 8

Visium-HER2+ - coefficient of variation of each ground-truth and predicted gene

# Load the per-gene variation table for Visium-HER2+ and harmonise the
# "ST_Net" label with the "ST-Net" spelling used by `dl_method_pal`.
whole_bc_variance_data_combine <-
  readRDS("data/processed/visium/whole_bc_variance_data_combine.rds") %>%
  mutate(Group = case_when(
    Group == "ST_Net" ~ "ST-Net",
    TRUE ~ Group
  ))

# Sort model order (facets) and gene-set order (x axis).
# factor() already returns a factor, so no further as.factor() is needed.
# NOTE(review): "iStar" is not among the levels here (it is in the kidney
# figure); any iStar rows would become NA - confirm that is intended.
whole_bc_variance_data_combine$Group <- factor(
  whole_bc_variance_data_combine$Group,
  levels = c("ST-Net", "HisToGene", "GeneCodeR", "DeepSpaCE", "DeepPT",
             "Hist2ST", "EGNv1", "EGNv2", "TCGN", "THItoGene")
)
whole_bc_variance_data_combine$Type <- factor(
  whole_bc_variance_data_combine$Type,
  levels = c("GT 990 genes", "Predicted 990 genes",
             "GT 274 genes", "Predicted 274 genes")
)

# Plot gene variation across all spots: one panel per model, boxplots of the
# `Variance` column for each gene set.
whole_bc_gene_variation <- ggplot(
  whole_bc_variance_data_combine,
  aes(x = Type, y = Variance, fill = Group, col = Group)
) +
  geom_boxplot(position = position_dodge(width = 0.8), width = 0.6, alpha = 0.5) +
  facet_wrap(~Group, nrow = 3, ncol = 4) +
  labs(title = "", x = "", y = "") +
  theme_minimal() +
  scale_fill_manual(values = dl_method_pal) +
  scale_color_manual(values = dl_method_pal) +
  theme(
    legend.title = element_blank(),
    axis.text.x = element_text(angle = 35, hjust = 0.86, size = 5.6),
    panel.grid.major = element_blank(),
    panel.grid.minor = element_blank(),
    # `linewidth` replaces the `size` argument deprecated in ggplot2 3.4.0
    panel.background = element_rect(colour = "black", linewidth = 0.7, fill = NA)
  )

whole_bc_gene_variation

Supplementary Figure 9 - dataset sparsity

# Load processed gene expression files of all datasets
processed_expr_her2 <- read.csv("data/processed/her2st/processed_expr.csv", row.names = 1)
processed_expr_cscc <- read.csv("data/processed/cscc/processed_expr.csv", row.names = 1)
processed_expr_visium_bc_762 <- read.csv("data/processed/visium/processed_expr_visium_bc_762.csv", row.names = 1)
processed_expr_visium_bc_990 <- read.csv("data/processed/visium/processed_visium_bc_expr_990.csv", row.names = 1)
processed_expr_kidney <- read.csv("data/processed/visium/processed_expr_kidney.csv", row.names = 1)
processed_expr_whole_bc <- read.csv("data/processed/visium/processed_whole_bc_expr.csv", row.names = 1)

# Sparsity = fraction of entries equal to zero in each column of the
# expression table (presumably one column per gene - confirm orientation).
sparsity_list <- list(
  her2 = colMeans(processed_expr_her2 == 0),
  cscc = colMeans(processed_expr_cscc == 0),
  visium_bc_762_genes = colMeans(processed_expr_visium_bc_762 == 0),
  visium_bc_990_genes = colMeans(processed_expr_visium_bc_990 == 0),
  kidney = colMeans(processed_expr_kidney == 0),
  whole_bc = colMeans(processed_expr_whole_bc == 0)
)

# Stack the per-dataset vectors into one long data frame (Sparsity, Group)
sparsity_data <- do.call(rbind, lapply(names(sparsity_list), function(group) {
  data.frame(Sparsity = sparsity_list[[group]], Group = group)
}))

# x-axis order of the datasets
sparsity_data$Group <- factor(
  sparsity_data$Group,
  levels = c("her2", "cscc", "visium_bc_762_genes", "visium_bc_990_genes",
             "whole_bc", "kidney")
)

# Single palette shared by the fill and outline scales so they cannot drift
sparsity_pal <- c(
  "her2" = "#EDB8B0", "cscc" = "#88CEE6",
  "kidney" = "#FDBF6F", "visium_bc_762_genes" = "#B2D3A4",
  "visium_bc_990_genes" = "#92A5D1", "whole_bc" = "#B696B6"
)

fig_sparsity <- ggplot(
  sparsity_data,
  aes(x = Group, y = Sparsity, fill = Group, col = Group)
) +
  geom_boxplot(position = position_dodge(width = 0.8), width = 0.3, alpha = 0.5) +
  labs(title = "", x = "", y = "") +
  theme_minimal() +
  scale_fill_manual(values = sparsity_pal) +
  scale_color_manual(values = sparsity_pal) +
  theme(
    legend.title = element_blank(),
    axis.text.x = element_text(angle = 0, hjust = 0.5, size = 10),
    panel.grid.major = element_blank(),
    panel.grid.minor = element_blank(),
    # `linewidth` replaces the `size` argument deprecated in ggplot2 3.4.0
    panel.background = element_rect(colour = "black", linewidth = 0.7, fill = NA)
  )

fig_sparsity

Supplementary figure 10 - Kidney 992 HVGs + 145 HSGs

# Plot Boxplots of Correlations
# Tag every row with "<model> <n> genes" so both gene sets can share one axis.
pred_feat_cor_kidney_145 <- pred_feat_cor_kidney_145 %>%
  mutate(Type = paste0(model_id, " 145 genes"))
pred_feat_cor_kidney_992 <- pred_feat_cor_kidney_992 %>%
  mutate(Type = paste0(model_id, " 992 genes"))

# Mean of every numeric column per (gene, model_id, Type); NAs are ignored.
# NOTE(review): the Visium-HER2+ pipeline first averages within `train_fold`
# and then across folds. Here `train_fold` is not a grouping key, so this is a
# single pooled mean (a second summarise over the same keys would be a no-op).
# Confirm the pooled mean is what is intended for the kidney data.
pred_feat_cor_kidney <- rbind(pred_feat_cor_kidney_145, pred_feat_cor_kidney_992) %>%
  select(gene, train_fold, model_id,
         img_id, cor_pearson, mi, ssim, auc_0, Type) %>%
  group_by(gene, model_id, Type) %>%
  summarise(across(where(is.numeric), function(col) mean(col, na.rm = TRUE)))

# Shared theme for the metric panels (also picked up by later figures).
# `linewidth` replaces the `size` argument deprecated in ggplot2 3.4.0.
th <- theme(text = element_text(size = 12),
            axis.text.x = element_text(angle = 0, hjust = 0.5),
            panel.grid.major = element_blank(),
            panel.grid.minor = element_blank(),
            panel.background = element_rect(colour = "black", linewidth = 0.7, fill = NA))

# Axis order: models listed from iStar to ST-Net, each as its 145-gene entry
# followed by its 992-gene entry.
kidney_type_order <- paste0(
  rep(c("iStar", "THItoGene", "TCGN", "EGNv2", "EGNv1", "Hist2ST", "DeepPT",
        "DeepSpaCE", "GeneCodeR", "HisToGene", "ST-Net"), each = 2),
  c(" 145 genes", " 992 genes")
)

# Sort rows into that order, then freeze it as the factor level order (levels
# are taken from order of appearance, so unlisted Types sort last rather than
# becoming NA).
pred_feat_cor_kidney <- pred_feat_cor_kidney %>%
  arrange(factor(Type, levels = kidney_type_order))
pred_feat_cor_kidney$Type <- factor(pred_feat_cor_kidney$Type,
                                    levels = unique(pred_feat_cor_kidney$Type))

# Violin + box plots of the four per-gene metrics, one facet per metric, with
# one horizontal entry per "<model> <n> genes" level of `Type`.
fig_kidney <- pred_feat_cor_kidney %>%
  pivot_longer(cols = c("cor_pearson", "mi", "ssim", "auc_0"),
               names_to = "metric",
               values_to = "value") %>%
  # Map column names to display labels; the factor fixes the facet order.
  # Anything mapped to "other" is outside the levels and therefore becomes NA.
  mutate(metric = case_when(
    metric == "cor_pearson" ~ "PCC",
    metric == "mi" ~ "MI",
    metric == "ssim" ~ "SSIM",
    metric == "auc_0" ~ "AUC",
    TRUE ~ "other"
  ) %>%
    factor(levels = c("PCC", "MI", "SSIM", "AUC"))) %>%
  ggplot() +
  aes(x = Type, y = value, fill = model_id, col = model_id) +
  geom_violin(alpha = 0.2) +
  geom_boxplot(alpha = 0.5, width = 0.3) +
  facet_wrap(~metric, nrow = 2, scales = "free") +
  coord_flip() +
  scale_color_manual(values = dl_method_pal) +
  scale_fill_manual(values = dl_method_pal) +
  theme(legend.position = "none") +
  th +
  labs(x = "", y = "")

fig_kidney
## Warning: Removed 1 row containing non-finite outside the scale range
## (`stat_ydensity()`).
## Warning: Removed 1 row containing non-finite outside the scale range
## (`stat_boxplot()`).

Supplementary figure 11 - Whole BC 990 HVGs + 274 HSGs

# Plot Boxplots of Correlations
# Tag every row with "<model> <n> genes" so both gene sets can share one axis.
pred_feat_cor_whole_bc_274 <- pred_feat_cor_whole_bc_274 %>%
  mutate(Type = paste0(model_id, " 274 genes"))
pred_feat_cor_whole_bc_990 <- pred_feat_cor_whole_bc_990 %>%
  mutate(Type = paste0(model_id, " 990 genes"))
pred_feat_cor_whole_bc <- rbind(pred_feat_cor_whole_bc_274, pred_feat_cor_whole_bc_990)

# Two-stage mean of the numeric metric columns on test predictions (NAs are
# ignored): first within each training fold, then across folds.
pred_feat_cor_whole_bc <- pred_feat_cor_whole_bc %>%
  filter(pred_type == "test") %>%
  select(gene, pred_type, train_fold, model_id,
         img_id, cor_pearson, mi, ssim, auc_0, Type) %>%
  # Stage 1: average over images within each fold
  group_by(gene, pred_type, train_fold, model_id, Type) %>%
  summarise(across(where(is.numeric), function(col) mean(col, na.rm = TRUE))) %>%
  # Stage 2: average the fold-level means for each gene
  group_by(pred_type, gene, model_id, Type) %>%
  summarise(across(where(is.numeric), function(col) mean(col, na.rm = TRUE)))

# Shared theme for the metric panels (also picked up by later figures).
# `linewidth` replaces the `size` argument deprecated in ggplot2 3.4.0.
th <- theme(text = element_text(size = 12),
            axis.text.x = element_text(angle = 0, hjust = 0.5),
            panel.grid.major = element_blank(),
            panel.grid.minor = element_blank(),
            panel.background = element_rect(colour = "black", linewidth = 0.7, fill = NA))

# Axis order: models listed from iStar to ST-Net, each as its 274-gene entry
# followed by its 990-gene entry.
whole_bc_type_order <- paste0(
  rep(c("iStar", "THItoGene", "TCGN", "EGNv2", "EGNv1", "Hist2ST", "DeepPT",
        "DeepSpaCE", "GeneCodeR", "HisToGene", "ST-Net"), each = 2),
  c(" 274 genes", " 990 genes")
)

# Sort rows into that order, then freeze it as the factor level order (levels
# are taken from order of appearance, so unlisted Types sort last rather than
# becoming NA).
pred_feat_cor_whole_bc <- pred_feat_cor_whole_bc %>%
  arrange(factor(Type, levels = whole_bc_type_order))
pred_feat_cor_whole_bc$Type <- factor(pred_feat_cor_whole_bc$Type,
                                      levels = unique(pred_feat_cor_whole_bc$Type))

# Violin + box plots of the four per-gene metrics, one facet per metric, with
# one horizontal entry per "<model> <n> genes" level of `Type`.
fig_whole_bc <- pred_feat_cor_whole_bc %>%
  pivot_longer(cols = c("cor_pearson", "mi", "ssim", "auc_0"),
               names_to = "metric",
               values_to = "value") %>%
  # Map column names to display labels; the factor fixes the facet order.
  # Anything mapped to "other" is outside the levels and therefore becomes NA.
  mutate(metric = case_when(
    metric == "cor_pearson" ~ "PCC",
    metric == "mi" ~ "MI",
    metric == "ssim" ~ "SSIM",
    metric == "auc_0" ~ "AUC",
    TRUE ~ "other"
  ) %>%
    factor(levels = c("PCC", "MI", "SSIM", "AUC"))) %>%
  ggplot() +
  aes(x = Type, y = value, fill = model_id, col = model_id) +
  geom_violin(alpha = 0.2) +
  geom_boxplot(alpha = 0.5, width = 0.3) +
  facet_wrap(~metric, nrow = 2, scales = "free") +
  coord_flip() +
  scale_color_manual(values = dl_method_pal) +
  scale_fill_manual(values = dl_method_pal) +
  theme(legend.position = "none") +
  th +
  labs(x = "", y = "")

fig_whole_bc
## Warning: Removed 21 rows containing non-finite outside the scale range
## (`stat_ydensity()`).
## Warning: Removed 21 rows containing non-finite outside the scale range
## (`stat_boxplot()`).

Supplementary Figure 13 - HER2+ trained model on Visium-HER2+ (external validation)

# Put the rows in the display order of the models, then lock that order in as
# the factor levels of `model_id` (levels follow order of appearance).
pred_feat_cor_her2_bc <- arrange(
  pred_feat_cor_her2_bc,
  factor(
    model_id,
    levels = c(
      "ST-Net", "HisToGene", "GeneCodeR", "DeepSpaCE", "DeepPT",
      "Hist2ST", "EGNv1", "EGNv2", "TCGN", "THItoGene"
    )
  )
)
pred_feat_cor_her2_bc$model_id <- factor(
  pred_feat_cor_her2_bc$model_id,
  levels = unique(pred_feat_cor_her2_bc$model_id)
)

# Distribution of PCC per model: violin with a narrow boxplot on top
p_pat_cor <- ggplot(
  pred_feat_cor_her2_bc,
  aes(x = model_id, y = cor_pearson, colour = model_id, fill = model_id)
) +
  geom_violin(alpha = 0.2) +
  geom_boxplot(alpha = 0.65, width = 0.3) +
  labs(y = "PCC", x = "") +
  scale_colour_manual(values = dl_method_pal) +
  scale_fill_manual(values = dl_method_pal) +
  theme(legend.position = "none") +
  th

p_pat_cor

QC Metrics

# Read the histology QC result tables (tab-separated; the first 5 lines are
# skipped). With row.names = NULL the first column is named "row.names";
# the column names are then rotated left by one position - presumably because
# the data rows carry one more field than the header line (TODO confirm) -
# and the `comments` and `row.names` columns are dropped.
st_her2_qc_dat <- read.delim("data/raw/qc_her2st_results.tsv",
                             sep = "\t", header = TRUE, skip = 5, row.names = NULL) %>%
  `colnames<-`(colnames(.)[c(2:ncol(.), 1)]) %>%
  dplyr::select(-comments, -row.names)

# Derive `pat_id` from the file name and keep only the QC metric columns.
# fixed = TRUE makes ".tiff" a literal string; as a regex the "." would match
# any character.
st_her2_qc_df <- st_her2_qc_dat %>%
  mutate(pat_id = gsub(".tiff", "", filename, fixed = TRUE)) %>%
  dplyr::select(-c(filename, image_bounding_box, base_mag,
                   type, levels, comment, warnings, mpp_x, mpp_y,
                   pixels_to_use, height, width))

# Same processing for the cSCC table, whose file-name column is called
# `dataset.filename`.
st_cscc_qc_dat <- read.delim("data/raw/qc_cscc_results.tsv",
                             sep = "\t", header = TRUE, skip = 5, row.names = NULL) %>%
  `colnames<-`(colnames(.)[c(2:ncol(.), 1)]) %>%
  dplyr::select(-comments, -row.names)

st_cscc_qc_df <- st_cscc_qc_dat %>%
  mutate(pat_id = gsub(".tiff", "", `dataset.filename`, fixed = TRUE)) %>%
  dplyr::select(-c(`dataset.filename`, image_bounding_box, base_mag,
                   type, levels, comment, warnings, mpp_x, mpp_y,
                   pixels_to_use, height, width))

Supplementary Figure 14

Dotplot of correlation between various histology QC metrics and gene-level correlations for each method in the HER2+ ST dataset and the CSCC ST dataset.

# Heatmap of metrics
# For every (dataset, model, QC metric) combination, correlate the QC metric
# value with the test-set gene-level PCC using cor.test(), keeping the estimate
# and its p-value; p-values are then FDR-adjusted across all combinations.
qc_perf_cor_df_comb <- st_her2_qc_df %>%
  left_join(
    pred_feat_cor %>%
      filter(pred_type == "test") %>%
      select(img_id, model_id, cor_pearson),
    by = c("pat_id" = "img_id")
  ) %>%
  mutate(dataset = "HER2+ ST") %>%
  bind_rows(
    st_cscc_qc_df %>%
      left_join(
        pred_feat_cor_cscc %>%
          filter(pred_type == "test") %>%
          select(img_id, model_id, cor_pearson),
        by = c("pat_id" = "img_id")
      ) %>%
      mutate(dataset = "cSCC ST")
  ) %>%
  # Every column that is not an identifier or the PCC is treated as a QC metric
  pivot_longer(
    cols = -c("pat_id", "cor_pearson", "model_id", "dataset"),
    names_to = "qc_metric",
    values_to = "value"
  ) %>%
  group_by(dataset, model_id, qc_metric) %>%
  group_modify(function(grp_rows, grp_key) {
    test_res <- cor.test(grp_rows$cor_pearson, grp_rows$value)
    data.frame(
      qc_perf_cor = test_res$estimate,
      p_val = test_res$p.value
    )
  }) %>%
  ungroup() %>%
  mutate(p_val_adj = p.adjust(p_val, method = "fdr"))

# Wide table of correlations for both datasets: one row per (dataset, model),
# one column per QC metric.
qc_mat <- pivot_wider(
  select(qc_perf_cor_df_comb, -c(p_val, p_val_adj)),
  names_from = "qc_metric",
  values_from = "qc_perf_cor"
)

# HER2+ ST only, as a data frame with the model names as row names (the
# `dataset` column is still present alongside the QC metric columns).
qc_mat_her2 <- qc_perf_cor_df_comb %>%
  select(-c(p_val, p_val_adj)) %>%
  filter(dataset == "HER2+ ST") %>%
  pivot_wider(names_from = "qc_metric", values_from = "qc_perf_cor") %>%
  tibble::column_to_rownames("model_id")

# Calculate the distance matrix for rows and columns.
# Rows (models) use the HER2+ ST table only. Its character `dataset` column is
# dropped first: passing it to dist() coerces it to NA with a warning.
# dist() already returns a "dist" object, so no as.dist() wrapper is needed.
qc_metric_cols_her2 <- setdiff(colnames(qc_mat_her2), "dataset")
dist_rows <- dist(qc_mat_her2[, qc_metric_cols_her2, drop = FALSE],
                  method = "euclidean")
# Columns (QC metrics) use the correlations from both datasets
dist_cols <- dist(t(qc_mat %>%
                      select(-c(dataset, model_id))), method = "euclidean")

# Perform hierarchical clustering on rows and columns
row_clusters <- hclust(dist_rows, method = "complete")
col_clusters <- hclust(dist_cols, method = "complete")

# Get the order of rows and columns based on clustering
# (dendsort() reorders the dendrogram branches before the order is read off)
row_order <- order.dendrogram(
  as.dendrogram(dendsort(row_clusters))
)
column_order <- order.dendrogram(
  as.dendrogram(dendsort(col_clusters))
)

# Column dendrogram of the QC metrics, drawn with ggtree and laid out as a
# dendrogram so it can sit above the dot plot.
ddgram <- as.dendrogram(dendsort(col_clusters))
ggtree_plot <- ggtree::ggtree(
  ddgram,
  branch.length = "none",
  size = .3
)

p_dendro <- ggtree_plot + layout_dendrogram()

# Dot plot: x = QC metric (ordered by the tip positions of the tree),
# y = model (ordered by the row clustering), colour = PCC, size = |PCC|.
dotplot <- qc_perf_cor_df_comb %>%
  mutate(
    qc_metric = factor(
      qc_metric,
      levels = ggtree_plot$data %>%
        filter(isTip) %>%
        arrange(y) %>%
        pull(label)
    )
  ) %>%
  mutate(
    model_id = factor(
      model_id,
      levels = rownames(qc_mat_her2)[row_order]
    )
  ) %>%
  ggplot(aes(
    x = qc_metric, y = model_id,
    fill = qc_perf_cor, colour = qc_perf_cor,
    size = abs(qc_perf_cor)
  )) +
  geom_point() +
  facet_wrap(vars(dataset), nrow = 2) +
  scale_colour_gradient2(high = "#cc2010", low = "#66a6cc", mid = "#fff8de") +
  scale_fill_gradient2(high = "#cc2010", low = "#66a6cc", mid = "#fff8de") +
  scale_size(range = c(3, 7)) +
  labs(
    y = "", x = "Histology QC Metric", size = "|PCC|",
    colour = "PCC", fill = "PCC"
  ) +
  theme_bw() +
  theme(
    axis.text.x = element_text(size = 10, angle = 30, hjust = 1, vjust = 1),
    axis.text.y = element_text(size = 16),
    strip.text = element_text(size = 16),
    panel.grid = element_blank(),
    plot.margin = margin(l = 0 + 55)
  )

# Stack tree over dot plot; the empty middle row with a negative relative
# height pulls the two panels closer together.
p_qc_hm_comb <- plot_grid(
  p_dendro, NULL, dotplot,
  nrow = 3,
  rel_heights = c(1.5, -.25, 10),
  align = "v",
  axis = "lr"
)
p_qc_hm_comb

Rank of each method

# Per-gene test-set metrics (PCC, MI, SSIM, AUC) for the HER2+ and cSCC
# datasets: mean within each training fold, then mean across folds.
# NAs are ignored in both stages.
pred_gene_msr_df <- pred_feat_cor %>%
  mutate(dataset = "her2+") %>%
  bind_rows(
    pred_feat_cor_cscc %>%
      mutate(dataset = "cscc")
  ) %>%
  filter(pred_type == "test") %>%
  select(gene, pred_type, train_fold, model_id, dataset,
         img_id, cor_pearson, mi, ssim, auc_0) %>%
  # Stage 1: average over images within each fold
  group_by(gene, pred_type, train_fold, model_id, dataset) %>%
  summarise(across(where(is.numeric), function(col) mean(col, na.rm = TRUE))) %>%
  # Stage 2: average the fold-level means for each gene
  group_by(pred_type, gene, model_id, dataset) %>%
  summarise(across(where(is.numeric), function(col) mean(col, na.rm = TRUE)))

# Average each metric over each method and rank
# Three gene sets are stacked before averaging:
#   "all"      - every row of the HER2+ and cSCC tables
#   "hv_genes" - HER2+ rows whose gene is in `hv_genes`
#   "SVGs"     - HER2+ rows matching a (gene, img_id) pair in `svg_genes`
# NOTE(review): the two subsets are built from `pred_feat_cor` only, so they
# contain no cSCC rows and no `dataset` value is assigned to them here -
# confirm that is intended.
pred_rank_df <- pred_feat_cor %>%
  mutate(dataset = "her2+") %>%
  bind_rows(
    pred_feat_cor_cscc %>%
      mutate(dataset = "cscc")
  ) %>%
  ungroup() %>%
  mutate(gene_set = "all") %>%
  bind_rows(
    pred_feat_cor %>%
      ungroup() %>%
      filter(gene %in% hv_genes) %>%
      mutate(gene_set = "hv_genes")
  ) %>%
  bind_rows(
    pred_feat_cor %>%
      ungroup() %>%
      semi_join(svg_genes, by = c("gene", "img_id")) %>%
      mutate(gene_set = "SVGs")
  ) %>%
  # Inf values in NRMSE, remove them (turned into NaN so na.rm drops them)
  mutate(across(c(nrmse_range, nrmse_sd),
                function(col) ifelse(col == Inf, NaN, col))) %>%
  # NOTE(review): `pat_id` is never used as a grouping key below and, being
  # non-numeric, is dropped by the first summarise - confirm whether a
  # per-patient averaging step was intended.
  mutate(pat_id = substr(img_id, 1, 1)) %>%
  # Average over images within dataset / gene / fold / gene set / model
  group_by(dataset, gene, pred_type, train_fold, gene_set, model_id) %>%
  summarise(across(where(is.numeric), function(col) mean(col, na.rm = TRUE))) %>%
  # Average over datasets
  group_by(gene, pred_type, train_fold, gene_set, model_id) %>%
  summarise(across(where(is.numeric), function(col) mean(col, na.rm = TRUE))) %>%
  # Average over all genes
  group_by(pred_type, train_fold, gene_set, model_id) %>%
  summarise(across(where(is.numeric), function(col) mean(col, na.rm = TRUE))) %>%
  # Average over each fold
  group_by(pred_type, gene_set, model_id) %>%
  summarise(across(where(is.numeric), function(col) mean(col, na.rm = TRUE))) %>%
  # summarise() drops the last grouping key, so the ranks below are computed
  # across models within each (pred_type, gene_set). Rank 1 = best: highest
  # for PCC / MI / SSIM / AUC, lowest for NRMSE / JS divergence.
  mutate(cor_pearson_r = rank(-cor_pearson),
         nrmse_r = rank(nrmse_range),
         js_div_r = rank(js_div),
         mi_r = rank(-mi),
         ssim_r = rank(-ssim),
         auc_r = rank(-auc_0)) %>%
  rowwise() %>%
  mutate(mean_rank = mean(c(cor_pearson_r, nrmse_r, js_div_r,
                            mi_r, ssim_r, auc_r))) %>%
  ungroup() %>%
  # Order the model factor by ascending mean rank (first appearance wins)
  arrange(mean_rank) %>%
  mutate(model_id = factor(model_id, levels = unique(model_id))) %>%
  pivot_longer(
    cols = c("cor_pearson_r", "nrmse_r",
             "js_div_r", "mi_r", "ssim_r", "auc_r", "mean_rank"),
    names_to = "metric_rank",
    values_to = "rank"
  ) %>%
  mutate(metric_rank = factor(metric_rank,
                              levels = c("cor_pearson_r", "nrmse_r",
                                         "js_div_r", "mi_r", "ssim_r",
                                         "auc_r", "mean_rank"))) %>%
  filter(pred_type != "val")

saveRDS(pred_rank_df %>%
          distinct(pred_type, gene_set, model_id, metric_rank, rank),
        file = "data/processed/her2st/pred_rank_df.rds")

Rank of each method - 2 Visium datasets

# Mean of every non-grouping numeric column within the current groups,
# ignoring NAs.
mean_numeric_by_group <- function(grouped_df) {
  summarise(
    grouped_df,
    across(where(is.numeric), function(col) mean(col, na.rm = TRUE))
  )
}

# Per-gene test-set metrics (PCC, MI, SSIM, AUC) for one Visium dataset:
# mean within each training fold, then mean across folds.
#   pred_feat_cor_df  per-image metric table with a `pred_type` column
#   dataset_name      label written to the `dataset` column
summarise_visium_gene_metrics <- function(pred_feat_cor_df, dataset_name) {
  pred_feat_cor_df %>%
    mutate(dataset = dataset_name) %>%
    filter(pred_type == "test") %>%
    select(gene, pred_type, train_fold, model_id, dataset,
           img_id, cor_pearson, mi, ssim, auc_0) %>%
    # Stage 1: average over images within each fold
    group_by(gene, pred_type, train_fold, model_id, dataset) %>%
    mean_numeric_by_group() %>%
    # Stage 2: average the fold-level means for each gene
    group_by(pred_type, gene, model_id, dataset) %>%
    mean_numeric_by_group()
}

# Average each metric over each method and rank the methods.
# Returns a long table with one row per (pred_type, model_id, metric_rank).
# Rank 1 = best: highest for PCC / MI / SSIM / AUC, lowest for NRMSE / JS
# divergence; `mean_rank` is the mean of the six per-metric ranks.
#   pred_feat_cor_df  per-image metric table with a `pred_type` column
#   dataset_name      label written to the `dataset` column
rank_visium_models <- function(pred_feat_cor_df, dataset_name) {
  rank_cols <- c("cor_pearson_r", "nrmse_r", "js_div_r",
                 "mi_r", "ssim_r", "auc_r", "mean_rank")

  pred_feat_cor_df %>%
    mutate(dataset = dataset_name) %>%
    ungroup() %>%
    # Inf values in NRMSE, remove them (turned into NaN so na.rm drops them)
    mutate(across(c(nrmse_range, nrmse_sd),
                  function(col) ifelse(col == Inf, NaN, col))) %>%
    # NOTE(review): `pat_id` is never used as a grouping key below and, being
    # non-numeric, is dropped by the first summarise - confirm whether a
    # per-patient averaging step was intended.
    mutate(pat_id = substr(img_id, 1, 1)) %>%
    # Average over images within dataset / gene / fold / model
    group_by(dataset, gene, pred_type, train_fold, model_id) %>%
    mean_numeric_by_group() %>%
    # Same keys without `dataset`
    group_by(gene, pred_type, train_fold, model_id) %>%
    mean_numeric_by_group() %>%
    # Average over all genes
    group_by(pred_type, train_fold, model_id) %>%
    mean_numeric_by_group() %>%
    # Average over each fold
    group_by(pred_type, model_id) %>%
    mean_numeric_by_group() %>%
    # summarise() drops the last grouping key, so ranks are computed across
    # models within each pred_type
    mutate(cor_pearson_r = rank(-cor_pearson),
           nrmse_r = rank(nrmse_range),
           js_div_r = rank(js_div),
           mi_r = rank(-mi),
           ssim_r = rank(-ssim),
           auc_r = rank(-auc_0)) %>%
    rowwise() %>%
    mutate(mean_rank = mean(c(cor_pearson_r, nrmse_r, js_div_r,
                              mi_r, ssim_r, auc_r))) %>%
    ungroup() %>%
    # Order the model factor by ascending mean rank
    arrange(mean_rank) %>%
    mutate(model_id = factor(model_id, levels = unique(model_id))) %>%
    pivot_longer(
      cols = all_of(rank_cols),
      names_to = "metric_rank",
      values_to = "rank"
    ) %>%
    mutate(metric_rank = factor(metric_rank, levels = rank_cols))
}

#save performance of Visium-HER2+
# Every row of this table is labelled as a test prediction
pred_feat_cor_whole_bc_990$pred_type <- "test"
pred_gene_msr_df <- summarise_visium_gene_metrics(pred_feat_cor_whole_bc_990,
                                                  "whole_bc")

pred_rank_df_whole_bc <- rank_visium_models(pred_feat_cor_whole_bc_990,
                                            "whole_bc")

saveRDS(pred_rank_df_whole_bc %>%
          distinct(pred_type, model_id, metric_rank, rank),
        file = "data/processed/visium/pred_rank_df_whole_bc_visium.rds")

#save performance of Visium-Kidney
# Every row of this table is labelled as a test prediction
pred_feat_cor_kidney_992$pred_type <- "test"
# Note: this overwrites the Visium-HER2+ `pred_gene_msr_df` computed above
pred_gene_msr_df <- summarise_visium_gene_metrics(pred_feat_cor_kidney_992,
                                                  "kidney")

pred_rank_df_kidney <- rank_visium_models(pred_feat_cor_kidney_992, "kidney")

saveRDS(pred_rank_df_kidney %>%
          distinct(pred_type, model_id, metric_rank, rank),
        file = "data/processed/visium/pred_rank_df_kidney_visium.rds")

Session

# Print the R version, platform, locale and attached/loaded package versions
# used for this run, for reproducibility.
sessionInfo()
## R version 4.4.1 (2024-06-14)
## Platform: x86_64-pc-linux-gnu
## Running under: Debian GNU/Linux 12 (bookworm)
## 
## Matrix products: default
## BLAS:   /usr/lib/x86_64-linux-gnu/openblas-pthread/libblas.so.3 
## LAPACK: /usr/lib/x86_64-linux-gnu/openblas-pthread/libopenblasp-r0.3.21.so;  LAPACK version 3.11.0
## 
## locale:
##  [1] LC_CTYPE=C.UTF-8       LC_NUMERIC=C           LC_TIME=C.UTF-8       
##  [4] LC_COLLATE=C.UTF-8     LC_MONETARY=C.UTF-8    LC_MESSAGES=C.UTF-8   
##  [7] LC_PAPER=C.UTF-8       LC_NAME=C              LC_ADDRESS=C          
## [10] LC_TELEPHONE=C         LC_MEASUREMENT=C.UTF-8 LC_IDENTIFICATION=C   
## 
## time zone: Australia/Sydney
## tzcode source: system (glibc)
## 
## attached base packages:
## [1] stats4    grid      stats     graphics  grDevices utils     datasets 
## [8] methods   base     
## 
## other attached packages:
##  [1] pROC_1.18.5                 infotheo_1.2.0.1           
##  [3] broom_1.0.7                 effsize_0.8.1              
##  [5] mclust_6.1.1                ggthemr_1.1.0              
##  [7] aricode_1.0.3               scran_1.32.0               
##  [9] scuttle_1.14.0              SingleCellExperiment_1.26.0
## [11] SummarizedExperiment_1.34.0 Biobase_2.64.0             
## [13] GenomicRanges_1.56.0        GenomeInfoDb_1.40.1        
## [15] IRanges_2.38.0              S4Vectors_0.42.0           
## [17] BiocGenerics_0.50.0         MatrixGenerics_1.16.0      
## [19] matrixStats_1.3.0           ggsignif_0.6.4             
## [21] cowplot_1.1.3               ggtree_3.12.0              
## [23] dendsort_0.3.4              circlize_0.4.16            
## [25] ComplexHeatmap_2.20.0       pals_1.8                   
## [27] viridisLite_0.4.2           ggplot2_3.5.1              
## [29] tidyr_1.3.1                 dplyr_1.1.4                
## [31] rlang_1.1.3                
## 
## loaded via a namespace (and not attached):
##   [1] magrittr_2.0.3            clue_0.3-65              
##   [3] GetoptLong_1.0.5          compiler_4.4.1           
##   [5] mgcv_1.9-1                DelayedMatrixStats_1.26.0
##   [7] png_0.1-8                 vctrs_0.6.5              
##   [9] maps_3.4.2                pkgconfig_2.0.3          
##  [11] shape_1.4.6.1             crayon_1.5.2             
##  [13] fastmap_1.2.0             backports_1.5.0          
##  [15] XVector_0.44.0            labeling_0.4.3           
##  [17] utf8_1.2.4                rmarkdown_2.27           
##  [19] UCSC.utils_1.0.0          purrr_1.0.2              
##  [21] bluster_1.14.0            xfun_0.44                
##  [23] zlibbioc_1.50.0           cachem_1.1.0             
##  [25] beachmat_2.20.0           aplot_0.2.2              
##  [27] jsonlite_1.8.8            highr_0.11               
##  [29] DelayedArray_0.30.1       BiocParallel_1.38.0      
##  [31] irlba_2.3.5.1             parallel_4.4.1           
##  [33] cluster_2.1.6             R6_2.5.1                 
##  [35] bslib_0.7.0               RColorBrewer_1.1-3       
##  [37] limma_3.60.2              jquerylib_0.1.4          
##  [39] Rcpp_1.0.12               iterators_1.0.14         
##  [41] knitr_1.46                splines_4.4.1            
##  [43] igraph_2.0.3              Matrix_1.7-0             
##  [45] tidyselect_1.2.1          rstudioapi_0.16.0        
##  [47] dichromat_2.0-0.1         abind_1.4-5              
##  [49] yaml_2.3.8                doParallel_1.0.17        
##  [51] codetools_0.2-20          plyr_1.8.9               
##  [53] lattice_0.22-6            tibble_3.2.1             
##  [55] treeio_1.28.0             withr_3.0.0              
##  [57] evaluate_0.23             gridGraphics_0.5-1       
##  [59] pillar_1.9.0              foreach_1.5.2            
##  [61] ggfun_0.1.4               generics_0.1.3           
##  [63] sparseMatrixStats_1.16.0  munsell_0.5.1            
##  [65] scales_1.3.0              tidytree_0.4.6           
##  [67] glue_1.7.0                metapod_1.12.0           
##  [69] mapproj_1.2.11            lazyeval_0.2.2           
##  [71] tools_4.4.1               BiocNeighbors_1.22.0     
##  [73] ScaledMatrix_1.12.0       locfit_1.5-9.9           
##  [75] fs_1.6.4                  ape_5.8                  
##  [77] edgeR_4.2.0               colorspace_2.1-0         
##  [79] nlme_3.1-165              GenomeInfoDbData_1.2.12  
##  [81] patchwork_1.2.0           BiocSingular_1.20.0      
##  [83] rsvd_1.0.5                cli_3.6.2                
##  [85] fansi_1.0.6               S4Arrays_1.4.1           
##  [87] gtable_0.3.5              yulab.utils_0.1.4        
##  [89] sass_0.4.9                digest_0.6.35            
##  [91] dqrng_0.4.0               SparseArray_1.4.8        
##  [93] ggplotify_0.1.2           farver_2.1.2             
##  [95] rjson_0.2.21              memoise_2.0.1            
##  [97] htmltools_0.5.8.1         lifecycle_1.0.4          
##  [99] httr_1.4.7                statmod_1.5.0            
## [101] GlobalOptions_0.1.2